From 084000813642779063a1701b621e86823da5121b Mon Sep 17 00:00:00 2001 From: "Michal W. Tarnowski" Date: Sat, 13 Apr 2019 22:00:06 +0200 Subject: [PATCH 001/279] Non-broadcast Div optimized --- .../internal/optimized/optimized_ops.h | 78 +++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h index e0231785c9c..3d9b61320ac 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -3214,6 +3214,84 @@ inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params, } } +inline void Div(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const float* input1_data, + const RuntimeShape& input2_shape, const float* input2_data, + const RuntimeShape& output_shape, float* output_data) { + gemmlowp::ScopedProfilingLabel label("Div"); + const float output_activation_min = params.float_activation_min; + const float output_activation_max = params.float_activation_max; + + int i = 0; + const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape); +#ifdef USE_NEON + static constexpr int kNewtonSteps = 2; + static const float32x4_t TWO_F32 = vdupq_n_f32(2.f); + const float32x4_t activation_min = vdupq_n_f32(output_activation_min); + const float32x4_t activation_max = vdupq_n_f32(output_activation_max); + for (; i <= size - 16; i += 16) { + const float32x4_t a10 = vld1q_f32(input1_data + i); + const float32x4_t a11 = vld1q_f32(input1_data + i + 4); + const float32x4_t a12 = vld1q_f32(input1_data + i + 8); + const float32x4_t a13 = vld1q_f32(input1_data + i + 12); + const float32x4_t a20 = vld1q_f32(input2_data + i); + const float32x4_t a21 = vld1q_f32(input2_data + i + 4); + const float32x4_t a22 = vld1q_f32(input2_data + i + 8); + const float32x4_t a23 = vld1q_f32(input2_data + i + 12); + + float32x4_t r0 = vrecpeq_f32(a20); + float32x4_t r1 = vrecpeq_f32(a21); + float32x4_t r2 = vrecpeq_f32(a22); + float32x4_t r3 = vrecpeq_f32(a23); + for (int k = 0; k < kNewtonSteps; ++k) { + r0 = vmulq_f32(r0, vsubq_f32(TWO_F32, vmulq_f32(r0, a20))); + r1 = vmulq_f32(r1, vsubq_f32(TWO_F32, vmulq_f32(r1, a21))); + r2 = vmulq_f32(r2, vsubq_f32(TWO_F32, vmulq_f32(r2, a22))); + r3 = vmulq_f32(r3, vsubq_f32(TWO_F32, vmulq_f32(r3, a23))); + } + + float32x4_t x0 = vmulq_f32(a10, r0); + float32x4_t x1 = vmulq_f32(a11, r1); + float32x4_t x2 = vmulq_f32(a12, r2); + float32x4_t x3 = vmulq_f32(a13, r3); + x0 = vmaxq_f32(activation_min, x0); + x1 = vmaxq_f32(activation_min, x1); + x2 = vmaxq_f32(activation_min, x2); + x3 = vmaxq_f32(activation_min, x3); + x0 = vminq_f32(activation_max, x0); + x1 = vminq_f32(activation_max, x1); + x2 = vminq_f32(activation_max, x2); + x3 = vminq_f32(activation_max, x3); + + vst1q_f32(output_data + i, x0); + vst1q_f32(output_data + i + 4, x1); + vst1q_f32(output_data + i + 8, x2); + vst1q_f32(output_data + i + 12, x3); + } + for (; i <= size - 4; i += 4) { + const float32x4_t a1 = vld1q_f32(input1_data + i); + const float32x4_t a2 = vld1q_f32(input2_data + i); + + float32x4_t r = vrecpeq_f32(a2); + for (int k = 0; k < kNewtonSteps; ++k) { + r = vmulq_f32(r, vsubq_f32(TWO_F32, vmulq_f32(r, a2))); + } + + float32x4_t x = vmulq_f32(a1, r); + x = vmaxq_f32(activation_min, x); + x = vminq_f32(activation_max, x); + + vst1q_f32(output_data + i, x); + } +#endif // NEON + + for (; i < size; ++i) { + output_data[i] = ActivationFunctionWithMinMax( + input1_data[i] / input2_data[i], output_activation_min, + output_activation_max); + } +} + // TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary // dimensionality if the runtime code does a single loop over one dimension // that handles broadcasting as the base case. The code generator would then From 43a06104a6f3ad4e93099b7f1750948056dd47a7 Mon Sep 17 00:00:00 2001 From: "Michal W. Tarnowski" Date: Sat, 13 Apr 2019 22:31:47 +0200 Subject: [PATCH 002/279] Explicit NEON typenames removed --- .../internal/optimized/optimized_ops.h | 46 +++++++++---------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h index 3d9b61320ac..0d4629c9446 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -3226,23 +3226,23 @@ inline void Div(const ArithmeticParams& params, const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape); #ifdef USE_NEON static constexpr int kNewtonSteps = 2; - static const float32x4_t TWO_F32 = vdupq_n_f32(2.f); - const float32x4_t activation_min = vdupq_n_f32(output_activation_min); - const float32x4_t activation_max = vdupq_n_f32(output_activation_max); + static const auto TWO_F32 = vdupq_n_f32(2.f); + const auto activation_min = vdupq_n_f32(output_activation_min); + const auto activation_max = vdupq_n_f32(output_activation_max); for (; i <= size - 16; i += 16) { - const float32x4_t a10 = vld1q_f32(input1_data + i); - const float32x4_t a11 = vld1q_f32(input1_data + i + 4); - const float32x4_t a12 = vld1q_f32(input1_data + i + 8); - const float32x4_t a13 = vld1q_f32(input1_data + i + 12); - const float32x4_t a20 = vld1q_f32(input2_data + i); - const float32x4_t a21 = vld1q_f32(input2_data + i + 4); - const float32x4_t a22 = vld1q_f32(input2_data + i + 8); - const float32x4_t a23 = vld1q_f32(input2_data + i + 12); + const auto a10 = vld1q_f32(input1_data + i); + const auto a11 = vld1q_f32(input1_data + i + 4); + const auto a12 = vld1q_f32(input1_data + i + 8); + const auto a13 = vld1q_f32(input1_data + i + 12); + const auto a20 = vld1q_f32(input2_data + i); + const auto a21 = vld1q_f32(input2_data + i + 4); + const auto a22 = vld1q_f32(input2_data + i + 8); + const auto a23 = vld1q_f32(input2_data + i + 12); - float32x4_t r0 = vrecpeq_f32(a20); - float32x4_t r1 = vrecpeq_f32(a21); - float32x4_t r2 = vrecpeq_f32(a22); - float32x4_t r3 = vrecpeq_f32(a23); + auto r0 = vrecpeq_f32(a20); + auto r1 = vrecpeq_f32(a21); + auto r2 = vrecpeq_f32(a22); + auto r3 = vrecpeq_f32(a23); for (int k = 0; k < kNewtonSteps; ++k) { r0 = vmulq_f32(r0, vsubq_f32(TWO_F32, vmulq_f32(r0, a20))); r1 = vmulq_f32(r1, vsubq_f32(TWO_F32, vmulq_f32(r1, a21))); @@ -3250,10 +3250,10 @@ inline void Div(const ArithmeticParams& params, r3 = vmulq_f32(r3, vsubq_f32(TWO_F32, vmulq_f32(r3, a23))); } - float32x4_t x0 = vmulq_f32(a10, r0); - float32x4_t x1 = vmulq_f32(a11, r1); - float32x4_t x2 = vmulq_f32(a12, r2); - float32x4_t x3 = vmulq_f32(a13, r3); + auto x0 = vmulq_f32(a10, r0); + auto x1 = vmulq_f32(a11, r1); + auto x2 = vmulq_f32(a12, r2); + auto x3 = vmulq_f32(a13, r3); x0 = vmaxq_f32(activation_min, x0); x1 = vmaxq_f32(activation_min, x1); x2 = vmaxq_f32(activation_min, x2); @@ -3269,15 +3269,15 @@ inline void Div(const ArithmeticParams& params, vst1q_f32(output_data + i + 12, x3); } for (; i <= size - 4; i += 4) { - const float32x4_t a1 = vld1q_f32(input1_data + i); - const float32x4_t a2 = vld1q_f32(input2_data + i); + const auto a1 = vld1q_f32(input1_data + i); + const auto a2 = vld1q_f32(input2_data + i); - float32x4_t r = vrecpeq_f32(a2); + auto r = vrecpeq_f32(a2); for (int k = 0; k < kNewtonSteps; ++k) { r = vmulq_f32(r, vsubq_f32(TWO_F32, vmulq_f32(r, a2))); } - float32x4_t x = vmulq_f32(a1, r); + auto x = vmulq_f32(a1, r); x = vmaxq_f32(activation_min, x); x = vminq_f32(activation_max, x); From 302358459136384bb23c075b633f4c0e3159df49 Mon Sep 17 00:00:00 2001 From: Clayne Robison Date: Tue, 21 May 2019 16:46:36 -0700 Subject: [PATCH 003/279] Add Dockerfile partials to support Mkl + MPI + Horovod; Remove trailing whitespace from python.partial.Dockerfile --- .../partials/devel-horovod.partial.Dockerfile | 3 ++ .../partials/horovod.partial.Dockerfile | 2 + .../partials/mpi.partial.Dockerfile | 44 ++++++++++++++++++ .../partials/ubuntu/python.partial.Dockerfile | 2 +- tensorflow/tools/dockerfiles/spec.yml | 33 ++++++++++++- .../dockerfiles/tests/build-mkl-horovod.sh | 46 +++++++++++++++++++ .../dockerfiles/tests/import-mkl-horovod.sh | 18 ++++++++ 7 files changed, 146 insertions(+), 2 deletions(-) create mode 100644 tensorflow/tools/dockerfiles/partials/devel-horovod.partial.Dockerfile create mode 100644 tensorflow/tools/dockerfiles/partials/horovod.partial.Dockerfile create mode 100644 tensorflow/tools/dockerfiles/partials/mpi.partial.Dockerfile create mode 100755 tensorflow/tools/dockerfiles/tests/build-mkl-horovod.sh create mode 100755 tensorflow/tools/dockerfiles/tests/import-mkl-horovod.sh diff --git a/tensorflow/tools/dockerfiles/partials/devel-horovod.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/devel-horovod.partial.Dockerfile new file mode 100644 index 00000000000..dab42914df3 --- /dev/null +++ b/tensorflow/tools/dockerfiles/partials/devel-horovod.partial.Dockerfile @@ -0,0 +1,3 @@ +# Check out horovod source code if --build-arg CHECKOUT_HOROVOD_SRC=1 +ARG CHECKOUT_HOROVOD_SRC=0 +RUN test "${CHECKOUT_HOROVOD_SRC}" -eq 1 && git clone --recursive https://github.com/uber/horovod.git /horovod_src || true diff --git a/tensorflow/tools/dockerfiles/partials/horovod.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/horovod.partial.Dockerfile new file mode 100644 index 00000000000..b8b6aab3af2 --- /dev/null +++ b/tensorflow/tools/dockerfiles/partials/horovod.partial.Dockerfile @@ -0,0 +1,2 @@ +# Install Horovod +RUN ${PIP} install --no-cache-dir horovod diff --git a/tensorflow/tools/dockerfiles/partials/mpi.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/mpi.partial.Dockerfile new file mode 100644 index 00000000000..5c0de90549f --- /dev/null +++ b/tensorflow/tools/dockerfiles/partials/mpi.partial.Dockerfile @@ -0,0 +1,44 @@ +# install libnuma, openssh, wget +RUN apt-get update && apt-get install -y --no-install-recommends --fix-missing \ + libnuma-dev \ + openssh-server \ + openssh-client \ + wget && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* || \ + yum -y update && yum -y install \ + numactl-devel \ + openssh-server \ + openssh-clients \ + wget && \ + yum clean all || \ + echo "Unsupported Linux distribution. Aborting!" && exit 1 + +# Install Open MPI +RUN mkdir /tmp/openmpi && \ + cd /tmp/openmpi && \ + wget https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.0.0.tar.gz && \ + tar zxf openmpi-4.0.0.tar.gz && \ + cd openmpi-4.0.0 && \ + ./configure --enable-orterun-prefix-by-default && \ + make -j $(nproc) all && \ + make install && \ + ldconfig && \ + rm -rf /tmp/openmpi + +# Create a wrapper for OpenMPI to allow running as root by default +RUN mv /usr/local/bin/mpirun /usr/local/bin/mpirun.real && \ + echo '#!/bin/bash' > /usr/local/bin/mpirun && \ + echo 'mpirun.real --allow-run-as-root "$@"' >> /usr/local/bin/mpirun && \ + chmod a+x /usr/local/bin/mpirun + +# Configure OpenMPI to run good defaults: +RUN echo "btl_tcp_if_exclude = lo,docker0" >> /usr/local/etc/openmpi-mca-params.conf + +# Install OpenSSH for MPI to communicate between containers +RUN mkdir -p /var/run/sshd + +# Allow OpenSSH to talk to containers without asking for confirmation +RUN cat /etc/ssh/ssh_config | grep -v StrictHostKeyChecking > /etc/ssh/ssh_config.new && \ + echo " StrictHostKeyChecking no" >> /etc/ssh/ssh_config.new && \ + mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config diff --git a/tensorflow/tools/dockerfiles/partials/ubuntu/python.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/ubuntu/python.partial.Dockerfile index 6af47319538..602bdbf5606 100644 --- a/tensorflow/tools/dockerfiles/partials/ubuntu/python.partial.Dockerfile +++ b/tensorflow/tools/dockerfiles/partials/ubuntu/python.partial.Dockerfile @@ -15,4 +15,4 @@ RUN ${PIP} --no-cache-dir install --upgrade \ setuptools # Some TF tools expect a "python" binary -RUN ln -s $(which ${PYTHON}) /usr/local/bin/python +RUN ln -s $(which ${PYTHON}) /usr/local/bin/python diff --git a/tensorflow/tools/dockerfiles/spec.yml b/tensorflow/tools/dockerfiles/spec.yml index 6fddfe000c6..ea5a70222f5 100644 --- a/tensorflow/tools/dockerfiles/spec.yml +++ b/tensorflow/tools/dockerfiles/spec.yml @@ -1,5 +1,5 @@ header: | - # Copyright 2018 The TensorFlow Authors. All Rights Reserved. + # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -83,6 +83,21 @@ slice_sets: - ubuntu/python - tensorflow - shell + - add_to_name: "-horovod" + dockerfile_exclusive_name: "horovod" + dockerfile_subdirectory: "mkl" + partials: + - ubuntu/version + - ubuntu/cpu + - ubuntu/python + - tensorflow + - mpi + - horovod + - shell + tests: + - import-mkl-horovod.sh + args: + - TF_PACKAGE=intel-tensorflow - add_to_name: "-gpu" dockerfile_exclusive_name: "gpu" args: @@ -110,6 +125,22 @@ slice_sets: - build-cpu.sh args: - CHECKOUT_TF_SRC=1 + - add_to_name: "devel-horovod" + dockerfile_exclusive_name: "devel-horovod" + dockerfile_subdirectory: "mkl" + partials: + - ubuntu/version + - ubuntu/devel-cpu + - ubuntu/python + - ubuntu/bazel + - mpi + - devel-horovod + - shell + tests: + - build-mkl-horovod.sh + args: + - CHECKOUT_TF_SRC=1 + - CHECKOUT_HOROVOD_SRC=1 - add_to_name: "devel-gpu" dockerfile_exclusive_name: "devel-gpu" partials: diff --git a/tensorflow/tools/dockerfiles/tests/build-mkl-horovod.sh b/tensorflow/tools/dockerfiles/tests/build-mkl-horovod.sh new file mode 100755 index 00000000000..62c2ffbc471 --- /dev/null +++ b/tensorflow/tools/dockerfiles/tests/build-mkl-horovod.sh @@ -0,0 +1,46 @@ +#!/usr/bin/env bash + +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + + + +# Download and build TensorFlow. +set -euxo pipefail +git clone --branch=master --depth=1 https://github.com/tensorflow/tensorflow.git /tensorflow +cd /tensorflow + +ln -s $(which ${PYTHON}) /usr/local/bin/python + +# Build TensorFlow with support for Intel(R) MKL-DNN +yes "" | ${PYTHON} configure.py && \ + bazel build -c opt --config=mkl --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" \ + tensorflow/tools/pip_package:build_pip_package && \ + bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/pip && \ + pip --no-cache-dir install --upgrade /tmp/pip/tensorflow-*.whl && \ + rm -rf /tmp/pip && \ + rm -rf /root/.cache + + +# download and build Horovod +git clone --recursive https://github.com/uber/horovod.git +cd horovod +# export environment +export HOROVOD_WITHOUT_PYTORCH=1 +export HOROVOD_WITH_TENSORFLOW=1 +python setup.py sdist +pip --no-cache-dir install --upgrade sdist/horovod*.tar.gz && \ + rm -rf sdist && \ + rm -rf /root/.cache diff --git a/tensorflow/tools/dockerfiles/tests/import-mkl-horovod.sh b/tensorflow/tools/dockerfiles/tests/import-mkl-horovod.sh new file mode 100755 index 00000000000..b1cae48c6ee --- /dev/null +++ b/tensorflow/tools/dockerfiles/tests/import-mkl-horovod.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env bash + +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +python -c 'from tensorflow.python import pywrap_tensorflow; pywrap_tensorflow.IsMklEnabled() or exit(1); import horovod.tensorflow as hvd' From 953a5deffab5b9b4fd652cf561d479d2066eb18e Mon Sep 17 00:00:00 2001 From: Clayne Robison Date: Thu, 23 May 2019 17:59:44 -0700 Subject: [PATCH 004/279] Adding the generated horovod Dockerfiles --- .../mkl/devel-horovod-jupyter.Dockerfile | 177 ++++++++++++++++++ .../dockerfiles/mkl/devel-horovod.Dockerfile | 158 ++++++++++++++++ .../mkl/horovod-jupyter.Dockerfile | 124 ++++++++++++ .../dockerfiles/mkl/horovod.Dockerfile | 105 +++++++++++ 4 files changed, 564 insertions(+) create mode 100644 tensorflow/tools/dockerfiles/dockerfiles/mkl/devel-horovod-jupyter.Dockerfile create mode 100644 tensorflow/tools/dockerfiles/dockerfiles/mkl/devel-horovod.Dockerfile create mode 100644 tensorflow/tools/dockerfiles/dockerfiles/mkl/horovod-jupyter.Dockerfile create mode 100644 tensorflow/tools/dockerfiles/dockerfiles/mkl/horovod.Dockerfile diff --git a/tensorflow/tools/dockerfiles/dockerfiles/mkl/devel-horovod-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/mkl/devel-horovod-jupyter.Dockerfile new file mode 100644 index 00000000000..e604832fa63 --- /dev/null +++ b/tensorflow/tools/dockerfiles/dockerfiles/mkl/devel-horovod-jupyter.Dockerfile @@ -0,0 +1,177 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# +# THIS IS A GENERATED DOCKERFILE. +# +# This file was assembled from multiple pieces, whose use is documented +# throughout. Please refer to the TensorFlow dockerfiles documentation +# for more information. + +ARG UBUNTU_VERSION=18.04 + +FROM ubuntu:${UBUNTU_VERSION} AS base + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + curl \ + git \ + libcurl3-dev \ + libfreetype6-dev \ + libhdf5-serial-dev \ + libzmq3-dev \ + pkg-config \ + rsync \ + software-properties-common \ + sudo \ + unzip \ + zip \ + zlib1g-dev \ + openjdk-8-jdk \ + openjdk-8-jre-headless \ + && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +ENV CI_BUILD_PYTHON python + +# CACHE_STOP is used to rerun future commands, otherwise cloning tensorflow will be cached and will not pull the most recent version +ARG CACHE_STOP=1 +# Check out TensorFlow source code if --build-arg CHECKOUT_TF_SRC=1 +ARG CHECKOUT_TF_SRC=0 +RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src || true + +ARG USE_PYTHON_3_NOT_2 +ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} +ARG PYTHON=python${_PY_SUFFIX} +ARG PIP=pip${_PY_SUFFIX} + +# See http://bugs.python.org/issue19846 +ENV LANG C.UTF-8 + +RUN apt-get update && apt-get install -y \ + ${PYTHON} \ + ${PYTHON}-pip + +RUN ${PIP} --no-cache-dir install --upgrade \ + pip \ + setuptools + +# Some TF tools expect a "python" binary +RUN ln -s $(which ${PYTHON}) /usr/local/bin/python + +RUN apt-get update && apt-get install -y \ + build-essential \ + curl \ + git \ + wget \ + openjdk-8-jdk \ + ${PYTHON}-dev \ + virtualenv \ + swig + +RUN ${PIP} --no-cache-dir install \ + Pillow \ + h5py \ + keras_applications \ + keras_preprocessing \ + matplotlib \ + mock \ + numpy \ + scipy \ + sklearn \ + pandas \ + portpicker \ + && test "${USE_PYTHON_3_NOT_2}" -eq 1 && true || ${PIP} --no-cache-dir install \ + enum34 + +# Install bazel +ARG BAZEL_VERSION=0.24.1 +RUN mkdir /bazel && \ + wget -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \ + wget -O /bazel/LICENSE.txt "https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE" && \ + chmod +x /bazel/installer.sh && \ + /bazel/installer.sh && \ + rm -f /bazel/installer.sh + +# install libnuma, openssh, wget +RUN apt-get update && apt-get install -y --no-install-recommends --fix-missing \ + libnuma-dev \ + openssh-server \ + openssh-client \ + wget && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* || \ + yum -y update && yum -y install \ + numactl-devel \ + openssh-server \ + openssh-clients \ + wget && \ + yum clean all || \ + echo "Unsupported Linux distribution. Aborting!" && exit 1 + +# Install Open MPI +RUN mkdir /tmp/openmpi && \ + cd /tmp/openmpi && \ + wget https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.0.0.tar.gz && \ + tar zxf openmpi-4.0.0.tar.gz && \ + cd openmpi-4.0.0 && \ + ./configure --enable-orterun-prefix-by-default && \ + make -j $(nproc) all && \ + make install && \ + ldconfig && \ + rm -rf /tmp/openmpi + +# Create a wrapper for OpenMPI to allow running as root by default +RUN mv /usr/local/bin/mpirun /usr/local/bin/mpirun.real && \ + echo '#!/bin/bash' > /usr/local/bin/mpirun && \ + echo 'mpirun.real --allow-run-as-root "$@"' >> /usr/local/bin/mpirun && \ + chmod a+x /usr/local/bin/mpirun + +# Configure OpenMPI to run good defaults: +RUN echo "btl_tcp_if_exclude = lo,docker0" >> /usr/local/etc/openmpi-mca-params.conf + +# Install OpenSSH for MPI to communicate between containers +RUN mkdir -p /var/run/sshd + +# Allow OpenSSH to talk to containers without asking for confirmation +RUN cat /etc/ssh/ssh_config | grep -v StrictHostKeyChecking > /etc/ssh/ssh_config.new && \ + echo " StrictHostKeyChecking no" >> /etc/ssh/ssh_config.new && \ + mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config + +# Check out horovod source code if --build-arg CHECKOUT_HOROVOD_SRC=1 +ARG CHECKOUT_HOROVOD_SRC=0 +RUN test "${CHECKOUT_HOROVOD_SRC}" -eq 1 && git clone --recursive https://github.com/uber/horovod.git /horovod_src || true + +COPY bashrc /etc/bash.bashrc +RUN chmod a+rwx /etc/bash.bashrc + +RUN ${PIP} install jupyter matplotlib +RUN ${PIP} install jupyter_http_over_ws +RUN jupyter serverextension enable --py jupyter_http_over_ws + +RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/ +RUN mkdir /.local && chmod a+rwx /.local +RUN apt-get install -y --no-install-recommends wget +WORKDIR /tf/tensorflow-tutorials +RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/basic_classification.ipynb +RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/basic_text_classification.ipynb +COPY readme-for-jupyter.md README.md +RUN apt-get autoremove -y && apt-get remove -y wget +WORKDIR /tf +EXPOSE 8888 + +RUN ${PYTHON} -m ipykernel.kernelspec + +CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/tf --ip 0.0.0.0 --no-browser --allow-root"] diff --git a/tensorflow/tools/dockerfiles/dockerfiles/mkl/devel-horovod.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/mkl/devel-horovod.Dockerfile new file mode 100644 index 00000000000..dd9ccc8bf3d --- /dev/null +++ b/tensorflow/tools/dockerfiles/dockerfiles/mkl/devel-horovod.Dockerfile @@ -0,0 +1,158 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# +# THIS IS A GENERATED DOCKERFILE. +# +# This file was assembled from multiple pieces, whose use is documented +# throughout. Please refer to the TensorFlow dockerfiles documentation +# for more information. + +ARG UBUNTU_VERSION=18.04 + +FROM ubuntu:${UBUNTU_VERSION} AS base + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + curl \ + git \ + libcurl3-dev \ + libfreetype6-dev \ + libhdf5-serial-dev \ + libzmq3-dev \ + pkg-config \ + rsync \ + software-properties-common \ + sudo \ + unzip \ + zip \ + zlib1g-dev \ + openjdk-8-jdk \ + openjdk-8-jre-headless \ + && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +ENV CI_BUILD_PYTHON python + +# CACHE_STOP is used to rerun future commands, otherwise cloning tensorflow will be cached and will not pull the most recent version +ARG CACHE_STOP=1 +# Check out TensorFlow source code if --build-arg CHECKOUT_TF_SRC=1 +ARG CHECKOUT_TF_SRC=0 +RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src || true + +ARG USE_PYTHON_3_NOT_2 +ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} +ARG PYTHON=python${_PY_SUFFIX} +ARG PIP=pip${_PY_SUFFIX} + +# See http://bugs.python.org/issue19846 +ENV LANG C.UTF-8 + +RUN apt-get update && apt-get install -y \ + ${PYTHON} \ + ${PYTHON}-pip + +RUN ${PIP} --no-cache-dir install --upgrade \ + pip \ + setuptools + +# Some TF tools expect a "python" binary +RUN ln -s $(which ${PYTHON}) /usr/local/bin/python + +RUN apt-get update && apt-get install -y \ + build-essential \ + curl \ + git \ + wget \ + openjdk-8-jdk \ + ${PYTHON}-dev \ + virtualenv \ + swig + +RUN ${PIP} --no-cache-dir install \ + Pillow \ + h5py \ + keras_applications \ + keras_preprocessing \ + matplotlib \ + mock \ + numpy \ + scipy \ + sklearn \ + pandas \ + portpicker \ + && test "${USE_PYTHON_3_NOT_2}" -eq 1 && true || ${PIP} --no-cache-dir install \ + enum34 + +# Install bazel +ARG BAZEL_VERSION=0.24.1 +RUN mkdir /bazel && \ + wget -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \ + wget -O /bazel/LICENSE.txt "https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE" && \ + chmod +x /bazel/installer.sh && \ + /bazel/installer.sh && \ + rm -f /bazel/installer.sh + +# install libnuma, openssh, wget +RUN apt-get update && apt-get install -y --no-install-recommends --fix-missing \ + libnuma-dev \ + openssh-server \ + openssh-client \ + wget && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* || \ + yum -y update && yum -y install \ + numactl-devel \ + openssh-server \ + openssh-clients \ + wget && \ + yum clean all || \ + echo "Unsupported Linux distribution. Aborting!" && exit 1 + +# Install Open MPI +RUN mkdir /tmp/openmpi && \ + cd /tmp/openmpi && \ + wget https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.0.0.tar.gz && \ + tar zxf openmpi-4.0.0.tar.gz && \ + cd openmpi-4.0.0 && \ + ./configure --enable-orterun-prefix-by-default && \ + make -j $(nproc) all && \ + make install && \ + ldconfig && \ + rm -rf /tmp/openmpi + +# Create a wrapper for OpenMPI to allow running as root by default +RUN mv /usr/local/bin/mpirun /usr/local/bin/mpirun.real && \ + echo '#!/bin/bash' > /usr/local/bin/mpirun && \ + echo 'mpirun.real --allow-run-as-root "$@"' >> /usr/local/bin/mpirun && \ + chmod a+x /usr/local/bin/mpirun + +# Configure OpenMPI to run good defaults: +RUN echo "btl_tcp_if_exclude = lo,docker0" >> /usr/local/etc/openmpi-mca-params.conf + +# Install OpenSSH for MPI to communicate between containers +RUN mkdir -p /var/run/sshd + +# Allow OpenSSH to talk to containers without asking for confirmation +RUN cat /etc/ssh/ssh_config | grep -v StrictHostKeyChecking > /etc/ssh/ssh_config.new && \ + echo " StrictHostKeyChecking no" >> /etc/ssh/ssh_config.new && \ + mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config + +# Check out horovod source code if --build-arg CHECKOUT_HOROVOD_SRC=1 +ARG CHECKOUT_HOROVOD_SRC=0 +RUN test "${CHECKOUT_HOROVOD_SRC}" -eq 1 && git clone --recursive https://github.com/uber/horovod.git /horovod_src || true + +COPY bashrc /etc/bash.bashrc +RUN chmod a+rwx /etc/bash.bashrc diff --git a/tensorflow/tools/dockerfiles/dockerfiles/mkl/horovod-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/mkl/horovod-jupyter.Dockerfile new file mode 100644 index 00000000000..49fed163605 --- /dev/null +++ b/tensorflow/tools/dockerfiles/dockerfiles/mkl/horovod-jupyter.Dockerfile @@ -0,0 +1,124 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# +# THIS IS A GENERATED DOCKERFILE. +# +# This file was assembled from multiple pieces, whose use is documented +# throughout. Please refer to the TensorFlow dockerfiles documentation +# for more information. + +ARG UBUNTU_VERSION=18.04 + +FROM ubuntu:${UBUNTU_VERSION} as base + +ARG USE_PYTHON_3_NOT_2 +ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} +ARG PYTHON=python${_PY_SUFFIX} +ARG PIP=pip${_PY_SUFFIX} + +# See http://bugs.python.org/issue19846 +ENV LANG C.UTF-8 + +RUN apt-get update && apt-get install -y \ + ${PYTHON} \ + ${PYTHON}-pip + +RUN ${PIP} --no-cache-dir install --upgrade \ + pip \ + setuptools + +# Some TF tools expect a "python" binary +RUN ln -s $(which ${PYTHON}) /usr/local/bin/python + +# Options: +# tensorflow +# tensorflow-gpu +# tf-nightly +# tf-nightly-gpu +# Set --build-arg TF_PACKAGE_VERSION=1.11.0rc0 to install a specific version. +# Installs the latest version by default. +ARG TF_PACKAGE=tensorflow +ARG TF_PACKAGE_VERSION= +RUN ${PIP} install ${TF_PACKAGE}${TF_PACKAGE_VERSION:+==${TF_PACKAGE_VERSION}} + +# install libnuma, openssh, wget +RUN apt-get update && apt-get install -y --no-install-recommends --fix-missing \ + libnuma-dev \ + openssh-server \ + openssh-client \ + wget && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* || \ + yum -y update && yum -y install \ + numactl-devel \ + openssh-server \ + openssh-clients \ + wget && \ + yum clean all || \ + echo "Unsupported Linux distribution. Aborting!" && exit 1 + +# Install Open MPI +RUN mkdir /tmp/openmpi && \ + cd /tmp/openmpi && \ + wget https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.0.0.tar.gz && \ + tar zxf openmpi-4.0.0.tar.gz && \ + cd openmpi-4.0.0 && \ + ./configure --enable-orterun-prefix-by-default && \ + make -j $(nproc) all && \ + make install && \ + ldconfig && \ + rm -rf /tmp/openmpi + +# Create a wrapper for OpenMPI to allow running as root by default +RUN mv /usr/local/bin/mpirun /usr/local/bin/mpirun.real && \ + echo '#!/bin/bash' > /usr/local/bin/mpirun && \ + echo 'mpirun.real --allow-run-as-root "$@"' >> /usr/local/bin/mpirun && \ + chmod a+x /usr/local/bin/mpirun + +# Configure OpenMPI to run good defaults: +RUN echo "btl_tcp_if_exclude = lo,docker0" >> /usr/local/etc/openmpi-mca-params.conf + +# Install OpenSSH for MPI to communicate between containers +RUN mkdir -p /var/run/sshd + +# Allow OpenSSH to talk to containers without asking for confirmation +RUN cat /etc/ssh/ssh_config | grep -v StrictHostKeyChecking > /etc/ssh/ssh_config.new && \ + echo " StrictHostKeyChecking no" >> /etc/ssh/ssh_config.new && \ + mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config + +# Install Horovod +RUN ${PIP} install --no-cache-dir horovod + +COPY bashrc /etc/bash.bashrc +RUN chmod a+rwx /etc/bash.bashrc + +RUN ${PIP} install jupyter matplotlib +RUN ${PIP} install jupyter_http_over_ws +RUN jupyter serverextension enable --py jupyter_http_over_ws + +RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/ +RUN mkdir /.local && chmod a+rwx /.local +RUN apt-get install -y --no-install-recommends wget +WORKDIR /tf/tensorflow-tutorials +RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/basic_classification.ipynb +RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/basic_text_classification.ipynb +COPY readme-for-jupyter.md README.md +RUN apt-get autoremove -y && apt-get remove -y wget +WORKDIR /tf +EXPOSE 8888 + +RUN ${PYTHON} -m ipykernel.kernelspec + +CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/tf --ip 0.0.0.0 --no-browser --allow-root"] diff --git a/tensorflow/tools/dockerfiles/dockerfiles/mkl/horovod.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/mkl/horovod.Dockerfile new file mode 100644 index 00000000000..a8fd040dca7 --- /dev/null +++ b/tensorflow/tools/dockerfiles/dockerfiles/mkl/horovod.Dockerfile @@ -0,0 +1,105 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# +# THIS IS A GENERATED DOCKERFILE. +# +# This file was assembled from multiple pieces, whose use is documented +# throughout. Please refer to the TensorFlow dockerfiles documentation +# for more information. + +ARG UBUNTU_VERSION=18.04 + +FROM ubuntu:${UBUNTU_VERSION} as base + +ARG USE_PYTHON_3_NOT_2 +ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} +ARG PYTHON=python${_PY_SUFFIX} +ARG PIP=pip${_PY_SUFFIX} + +# See http://bugs.python.org/issue19846 +ENV LANG C.UTF-8 + +RUN apt-get update && apt-get install -y \ + ${PYTHON} \ + ${PYTHON}-pip + +RUN ${PIP} --no-cache-dir install --upgrade \ + pip \ + setuptools + +# Some TF tools expect a "python" binary +RUN ln -s $(which ${PYTHON}) /usr/local/bin/python + +# Options: +# tensorflow +# tensorflow-gpu +# tf-nightly +# tf-nightly-gpu +# Set --build-arg TF_PACKAGE_VERSION=1.11.0rc0 to install a specific version. +# Installs the latest version by default. +ARG TF_PACKAGE=tensorflow +ARG TF_PACKAGE_VERSION= +RUN ${PIP} install ${TF_PACKAGE}${TF_PACKAGE_VERSION:+==${TF_PACKAGE_VERSION}} + +# install libnuma, openssh, wget +RUN apt-get update && apt-get install -y --no-install-recommends --fix-missing \ + libnuma-dev \ + openssh-server \ + openssh-client \ + wget && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* || \ + yum -y update && yum -y install \ + numactl-devel \ + openssh-server \ + openssh-clients \ + wget && \ + yum clean all || \ + echo "Unsupported Linux distribution. Aborting!" && exit 1 + +# Install Open MPI +RUN mkdir /tmp/openmpi && \ + cd /tmp/openmpi && \ + wget https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.0.0.tar.gz && \ + tar zxf openmpi-4.0.0.tar.gz && \ + cd openmpi-4.0.0 && \ + ./configure --enable-orterun-prefix-by-default && \ + make -j $(nproc) all && \ + make install && \ + ldconfig && \ + rm -rf /tmp/openmpi + +# Create a wrapper for OpenMPI to allow running as root by default +RUN mv /usr/local/bin/mpirun /usr/local/bin/mpirun.real && \ + echo '#!/bin/bash' > /usr/local/bin/mpirun && \ + echo 'mpirun.real --allow-run-as-root "$@"' >> /usr/local/bin/mpirun && \ + chmod a+x /usr/local/bin/mpirun + +# Configure OpenMPI to run good defaults: +RUN echo "btl_tcp_if_exclude = lo,docker0" >> /usr/local/etc/openmpi-mca-params.conf + +# Install OpenSSH for MPI to communicate between containers +RUN mkdir -p /var/run/sshd + +# Allow OpenSSH to talk to containers without asking for confirmation +RUN cat /etc/ssh/ssh_config | grep -v StrictHostKeyChecking > /etc/ssh/ssh_config.new && \ + echo " StrictHostKeyChecking no" >> /etc/ssh/ssh_config.new && \ + mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config + +# Install Horovod +RUN ${PIP} install --no-cache-dir horovod + +COPY bashrc /etc/bash.bashrc +RUN chmod a+rwx /etc/bash.bashrc From 27e206450bb8017a67bb3f9f629c7655373888a6 Mon Sep 17 00:00:00 2001 From: Koan-Sin Tan Date: Wed, 18 Sep 2019 11:31:23 +0800 Subject: [PATCH 005/279] [xla] fix xla build on cuda devices without nccl Some cuda devices, such as Jetson devices, do not support NCCL. Building `@local_config_nccl//:nccl` on such kind of devices will cause problem. --- tensorflow/compiler/xla/service/gpu/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 7b871951ed0..3aeae5008f7 100755 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -415,7 +415,7 @@ tf_cuda_library( "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/stream_executor/cuda:cuda_activation", "//tensorflow/stream_executor/cuda:cuda_gpu_executor", - ] + if_cuda([ + ] + if_nccl([ "@local_config_nccl//:nccl", ]), ) From cd4ae94127ab68e00c6f615b0fabbc3c268fc3f5 Mon Sep 17 00:00:00 2001 From: Clayne Robison Date: Mon, 14 Oct 2019 02:14:27 -0700 Subject: [PATCH 006/279] Moving files to mkl_horovod folders to reflect usage --- .../devel-horovod-jupyter.Dockerfile | 20 ++++++++++------- .../devel-horovod.Dockerfile | 20 ++++++++++------- .../horovod-jupyter.Dockerfile | 22 +++++++++++-------- .../{mkl => mkl_horovod}/horovod.Dockerfile | 22 +++++++++++-------- .../partials/horovod.partial.Dockerfile | 2 -- .../devel-horovod.partial.Dockerfile | 0 .../mkl_horovod/horovod.partial.Dockerfile | 3 +++ .../{ => mkl_horovod}/mpi.partial.Dockerfile | 19 +++++++++------- tensorflow/tools/dockerfiles/spec.yml | 14 ++++++------ 9 files changed, 71 insertions(+), 51 deletions(-) rename tensorflow/tools/dockerfiles/dockerfiles/{mkl => mkl_horovod}/devel-horovod-jupyter.Dockerfile (90%) rename tensorflow/tools/dockerfiles/dockerfiles/{mkl => mkl_horovod}/devel-horovod.Dockerfile (88%) rename tensorflow/tools/dockerfiles/dockerfiles/{mkl => mkl_horovod}/horovod-jupyter.Dockerfile (85%) rename tensorflow/tools/dockerfiles/dockerfiles/{mkl => mkl_horovod}/horovod.Dockerfile (81%) delete mode 100644 tensorflow/tools/dockerfiles/partials/horovod.partial.Dockerfile rename tensorflow/tools/dockerfiles/partials/{ => mkl_horovod}/devel-horovod.partial.Dockerfile (100%) create mode 100644 tensorflow/tools/dockerfiles/partials/mkl_horovod/horovod.partial.Dockerfile rename tensorflow/tools/dockerfiles/partials/{ => mkl_horovod}/mpi.partial.Dockerfile (68%) diff --git a/tensorflow/tools/dockerfiles/dockerfiles/mkl/devel-horovod-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/devel-horovod-jupyter.Dockerfile similarity index 90% rename from tensorflow/tools/dockerfiles/dockerfiles/mkl/devel-horovod-jupyter.Dockerfile rename to tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/devel-horovod-jupyter.Dockerfile index e604832fa63..92118f0ade8 100644 --- a/tensorflow/tools/dockerfiles/dockerfiles/mkl/devel-horovod-jupyter.Dockerfile +++ b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/devel-horovod-jupyter.Dockerfile @@ -92,6 +92,7 @@ RUN ${PIP} --no-cache-dir install \ scipy \ sklearn \ pandas \ + future \ portpicker \ && test "${USE_PYTHON_3_NOT_2}" -eq 1 && true || ${PIP} --no-cache-dir install \ enum34 @@ -106,27 +107,30 @@ RUN mkdir /bazel && \ rm -f /bazel/installer.sh # install libnuma, openssh, wget -RUN apt-get update && apt-get install -y --no-install-recommends --fix-missing \ +RUN ( apt-get update && apt-get install -y --no-install-recommends --fix-missing \ libnuma-dev \ openssh-server \ openssh-client \ wget && \ apt-get clean && \ - rm -rf /var/lib/apt/lists/* || \ - yum -y update && yum -y install \ + rm -rf /var/lib/apt/lists/* ) || \ + ( yum -y update && yum -y install \ numactl-devel \ openssh-server \ openssh-clients \ wget && \ - yum clean all || \ - echo "Unsupported Linux distribution. Aborting!" && exit 1 + yum clean all ) || \ + ( echo "Unsupported Linux distribution. Aborting!" && exit 1 ) # Install Open MPI +# download realese version from official website as openmpi github master is not always stable +ARG OPENMPI_VERSION=openmpi-4.0.0 +ARG OPENMPI_DOWNLOAD_URL=https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.0.0.tar.gz RUN mkdir /tmp/openmpi && \ cd /tmp/openmpi && \ - wget https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.0.0.tar.gz && \ - tar zxf openmpi-4.0.0.tar.gz && \ - cd openmpi-4.0.0 && \ + wget ${OPENMPI_DOWNLOAD_URL} && \ + tar zxf ${OPENMPI_VERSION}.tar.gz && \ + cd ${OPENMPI_VERSION} && \ ./configure --enable-orterun-prefix-by-default && \ make -j $(nproc) all && \ make install && \ diff --git a/tensorflow/tools/dockerfiles/dockerfiles/mkl/devel-horovod.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/devel-horovod.Dockerfile similarity index 88% rename from tensorflow/tools/dockerfiles/dockerfiles/mkl/devel-horovod.Dockerfile rename to tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/devel-horovod.Dockerfile index dd9ccc8bf3d..338474678d2 100644 --- a/tensorflow/tools/dockerfiles/dockerfiles/mkl/devel-horovod.Dockerfile +++ b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/devel-horovod.Dockerfile @@ -92,6 +92,7 @@ RUN ${PIP} --no-cache-dir install \ scipy \ sklearn \ pandas \ + future \ portpicker \ && test "${USE_PYTHON_3_NOT_2}" -eq 1 && true || ${PIP} --no-cache-dir install \ enum34 @@ -106,27 +107,30 @@ RUN mkdir /bazel && \ rm -f /bazel/installer.sh # install libnuma, openssh, wget -RUN apt-get update && apt-get install -y --no-install-recommends --fix-missing \ +RUN ( apt-get update && apt-get install -y --no-install-recommends --fix-missing \ libnuma-dev \ openssh-server \ openssh-client \ wget && \ apt-get clean && \ - rm -rf /var/lib/apt/lists/* || \ - yum -y update && yum -y install \ + rm -rf /var/lib/apt/lists/* ) || \ + ( yum -y update && yum -y install \ numactl-devel \ openssh-server \ openssh-clients \ wget && \ - yum clean all || \ - echo "Unsupported Linux distribution. Aborting!" && exit 1 + yum clean all ) || \ + ( echo "Unsupported Linux distribution. Aborting!" && exit 1 ) # Install Open MPI +# download realese version from official website as openmpi github master is not always stable +ARG OPENMPI_VERSION=openmpi-4.0.0 +ARG OPENMPI_DOWNLOAD_URL=https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.0.0.tar.gz RUN mkdir /tmp/openmpi && \ cd /tmp/openmpi && \ - wget https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.0.0.tar.gz && \ - tar zxf openmpi-4.0.0.tar.gz && \ - cd openmpi-4.0.0 && \ + wget ${OPENMPI_DOWNLOAD_URL} && \ + tar zxf ${OPENMPI_VERSION}.tar.gz && \ + cd ${OPENMPI_VERSION} && \ ./configure --enable-orterun-prefix-by-default && \ make -j $(nproc) all && \ make install && \ diff --git a/tensorflow/tools/dockerfiles/dockerfiles/mkl/horovod-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/horovod-jupyter.Dockerfile similarity index 85% rename from tensorflow/tools/dockerfiles/dockerfiles/mkl/horovod-jupyter.Dockerfile rename to tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/horovod-jupyter.Dockerfile index 49fed163605..5ba0fe65500 100644 --- a/tensorflow/tools/dockerfiles/dockerfiles/mkl/horovod-jupyter.Dockerfile +++ b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/horovod-jupyter.Dockerfile @@ -54,27 +54,30 @@ ARG TF_PACKAGE_VERSION= RUN ${PIP} install ${TF_PACKAGE}${TF_PACKAGE_VERSION:+==${TF_PACKAGE_VERSION}} # install libnuma, openssh, wget -RUN apt-get update && apt-get install -y --no-install-recommends --fix-missing \ +RUN ( apt-get update && apt-get install -y --no-install-recommends --fix-missing \ libnuma-dev \ openssh-server \ openssh-client \ wget && \ apt-get clean && \ - rm -rf /var/lib/apt/lists/* || \ - yum -y update && yum -y install \ + rm -rf /var/lib/apt/lists/* ) || \ + ( yum -y update && yum -y install \ numactl-devel \ openssh-server \ openssh-clients \ wget && \ - yum clean all || \ - echo "Unsupported Linux distribution. Aborting!" && exit 1 + yum clean all ) || \ + ( echo "Unsupported Linux distribution. Aborting!" && exit 1 ) # Install Open MPI +# download realese version from official website as openmpi github master is not always stable +ARG OPENMPI_VERSION=openmpi-4.0.0 +ARG OPENMPI_DOWNLOAD_URL=https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.0.0.tar.gz RUN mkdir /tmp/openmpi && \ cd /tmp/openmpi && \ - wget https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.0.0.tar.gz && \ - tar zxf openmpi-4.0.0.tar.gz && \ - cd openmpi-4.0.0 && \ + wget ${OPENMPI_DOWNLOAD_URL} && \ + tar zxf ${OPENMPI_VERSION}.tar.gz && \ + cd ${OPENMPI_VERSION} && \ ./configure --enable-orterun-prefix-by-default && \ make -j $(nproc) all && \ make install && \ @@ -99,7 +102,8 @@ RUN cat /etc/ssh/ssh_config | grep -v StrictHostKeyChecking > /etc/ssh/ssh_confi mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config # Install Horovod -RUN ${PIP} install --no-cache-dir horovod +ARG HOROVOD_VERSION=0.16.4 +RUN ${PIP} install --no-cache-dir horovod==${HOROVOD_VERSION} COPY bashrc /etc/bash.bashrc RUN chmod a+rwx /etc/bash.bashrc diff --git a/tensorflow/tools/dockerfiles/dockerfiles/mkl/horovod.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/horovod.Dockerfile similarity index 81% rename from tensorflow/tools/dockerfiles/dockerfiles/mkl/horovod.Dockerfile rename to tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/horovod.Dockerfile index a8fd040dca7..e08b910a1bb 100644 --- a/tensorflow/tools/dockerfiles/dockerfiles/mkl/horovod.Dockerfile +++ b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/horovod.Dockerfile @@ -54,27 +54,30 @@ ARG TF_PACKAGE_VERSION= RUN ${PIP} install ${TF_PACKAGE}${TF_PACKAGE_VERSION:+==${TF_PACKAGE_VERSION}} # install libnuma, openssh, wget -RUN apt-get update && apt-get install -y --no-install-recommends --fix-missing \ +RUN ( apt-get update && apt-get install -y --no-install-recommends --fix-missing \ libnuma-dev \ openssh-server \ openssh-client \ wget && \ apt-get clean && \ - rm -rf /var/lib/apt/lists/* || \ - yum -y update && yum -y install \ + rm -rf /var/lib/apt/lists/* ) || \ + ( yum -y update && yum -y install \ numactl-devel \ openssh-server \ openssh-clients \ wget && \ - yum clean all || \ - echo "Unsupported Linux distribution. Aborting!" && exit 1 + yum clean all ) || \ + ( echo "Unsupported Linux distribution. Aborting!" && exit 1 ) # Install Open MPI +# download realese version from official website as openmpi github master is not always stable +ARG OPENMPI_VERSION=openmpi-4.0.0 +ARG OPENMPI_DOWNLOAD_URL=https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.0.0.tar.gz RUN mkdir /tmp/openmpi && \ cd /tmp/openmpi && \ - wget https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.0.0.tar.gz && \ - tar zxf openmpi-4.0.0.tar.gz && \ - cd openmpi-4.0.0 && \ + wget ${OPENMPI_DOWNLOAD_URL} && \ + tar zxf ${OPENMPI_VERSION}.tar.gz && \ + cd ${OPENMPI_VERSION} && \ ./configure --enable-orterun-prefix-by-default && \ make -j $(nproc) all && \ make install && \ @@ -99,7 +102,8 @@ RUN cat /etc/ssh/ssh_config | grep -v StrictHostKeyChecking > /etc/ssh/ssh_confi mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config # Install Horovod -RUN ${PIP} install --no-cache-dir horovod +ARG HOROVOD_VERSION=0.16.4 +RUN ${PIP} install --no-cache-dir horovod==${HOROVOD_VERSION} COPY bashrc /etc/bash.bashrc RUN chmod a+rwx /etc/bash.bashrc diff --git a/tensorflow/tools/dockerfiles/partials/horovod.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/horovod.partial.Dockerfile deleted file mode 100644 index b8b6aab3af2..00000000000 --- a/tensorflow/tools/dockerfiles/partials/horovod.partial.Dockerfile +++ /dev/null @@ -1,2 +0,0 @@ -# Install Horovod -RUN ${PIP} install --no-cache-dir horovod diff --git a/tensorflow/tools/dockerfiles/partials/devel-horovod.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/mkl_horovod/devel-horovod.partial.Dockerfile similarity index 100% rename from tensorflow/tools/dockerfiles/partials/devel-horovod.partial.Dockerfile rename to tensorflow/tools/dockerfiles/partials/mkl_horovod/devel-horovod.partial.Dockerfile diff --git a/tensorflow/tools/dockerfiles/partials/mkl_horovod/horovod.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/mkl_horovod/horovod.partial.Dockerfile new file mode 100644 index 00000000000..b2bb20f713d --- /dev/null +++ b/tensorflow/tools/dockerfiles/partials/mkl_horovod/horovod.partial.Dockerfile @@ -0,0 +1,3 @@ +# Install Horovod +ARG HOROVOD_VERSION=0.16.4 +RUN ${PIP} install --no-cache-dir horovod==${HOROVOD_VERSION} diff --git a/tensorflow/tools/dockerfiles/partials/mpi.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/mkl_horovod/mpi.partial.Dockerfile similarity index 68% rename from tensorflow/tools/dockerfiles/partials/mpi.partial.Dockerfile rename to tensorflow/tools/dockerfiles/partials/mkl_horovod/mpi.partial.Dockerfile index 5c0de90549f..67055ab244a 100644 --- a/tensorflow/tools/dockerfiles/partials/mpi.partial.Dockerfile +++ b/tensorflow/tools/dockerfiles/partials/mkl_horovod/mpi.partial.Dockerfile @@ -1,25 +1,28 @@ # install libnuma, openssh, wget -RUN apt-get update && apt-get install -y --no-install-recommends --fix-missing \ +RUN ( apt-get update && apt-get install -y --no-install-recommends --fix-missing \ libnuma-dev \ openssh-server \ openssh-client \ wget && \ apt-get clean && \ - rm -rf /var/lib/apt/lists/* || \ - yum -y update && yum -y install \ + rm -rf /var/lib/apt/lists/* ) || \ + ( yum -y update && yum -y install \ numactl-devel \ openssh-server \ openssh-clients \ wget && \ - yum clean all || \ - echo "Unsupported Linux distribution. Aborting!" && exit 1 + yum clean all ) || \ + ( echo "Unsupported Linux distribution. Aborting!" && exit 1 ) # Install Open MPI +# download realese version from official website as openmpi github master is not always stable +ARG OPENMPI_VERSION=openmpi-4.0.0 +ARG OPENMPI_DOWNLOAD_URL=https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.0.0.tar.gz RUN mkdir /tmp/openmpi && \ cd /tmp/openmpi && \ - wget https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.0.0.tar.gz && \ - tar zxf openmpi-4.0.0.tar.gz && \ - cd openmpi-4.0.0 && \ + wget ${OPENMPI_DOWNLOAD_URL} && \ + tar zxf ${OPENMPI_VERSION}.tar.gz && \ + cd ${OPENMPI_VERSION} && \ ./configure --enable-orterun-prefix-by-default && \ make -j $(nproc) all && \ make install && \ diff --git a/tensorflow/tools/dockerfiles/spec.yml b/tensorflow/tools/dockerfiles/spec.yml index ea5a70222f5..5a64b70bacb 100644 --- a/tensorflow/tools/dockerfiles/spec.yml +++ b/tensorflow/tools/dockerfiles/spec.yml @@ -40,7 +40,7 @@ releases: nightly: tag_specs: - "{nightly}{py}{jupyter}" - - "{ubuntu-devel}{py}" + - "{_TAG_PREFIX}{ubuntu-devel}{py}" # Built per-release and pushed to tensorflow/tensorflow # --arg _TAG_PREFIX= should be set to "1.11" (for example) or "latest". @@ -85,14 +85,14 @@ slice_sets: - shell - add_to_name: "-horovod" dockerfile_exclusive_name: "horovod" - dockerfile_subdirectory: "mkl" + dockerfile_subdirectory: "mkl_horovod" partials: - ubuntu/version - ubuntu/cpu - ubuntu/python - tensorflow - - mpi - - horovod + - mkl_horovod/mpi + - mkl_horovod/horovod - shell tests: - import-mkl-horovod.sh @@ -127,14 +127,14 @@ slice_sets: - CHECKOUT_TF_SRC=1 - add_to_name: "devel-horovod" dockerfile_exclusive_name: "devel-horovod" - dockerfile_subdirectory: "mkl" + dockerfile_subdirectory: "mkl_horovod" partials: - ubuntu/version - ubuntu/devel-cpu - ubuntu/python - ubuntu/bazel - - mpi - - devel-horovod + - mkl_horovod/mpi + - mkl_horovod/devel-horovod - shell tests: - build-mkl-horovod.sh From 1a489f5c03edc60e26b10fd8a3d13ef03711f50f Mon Sep 17 00:00:00 2001 From: Deven Desai Date: Wed, 6 Nov 2019 18:02:13 +0000 Subject: [PATCH 007/279] [ROCm] Fix for the broken ROCm CSB. The following commit breaks the --config=rocm build https://github.com/tensorflow/tensorflow/commit/bf9c196f37b9cbb3109b2891aaf9da85bf5f712a The above commit adds support for complex type in the optimizers. Complex types are not supported on the ROCm platform. Support for it needs to be excluded on the ROCm platform, and that is what this "fix" does. --- tensorflow/core/kernels/training_ops.cc | 72 ++++++++++++++----- .../core/kernels/training_ops_gpu.cu.cc | 40 ++++++++--- .../keras/optimizer_v2/adadelta_test.py | 3 +- .../python/keras/optimizer_v2/adagrad_test.py | 3 +- .../python/keras/optimizer_v2/rmsprop_test.py | 3 +- 5 files changed, 90 insertions(+), 31 deletions(-) diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc index 3490dc1ee80..467087b7864 100644 --- a/tensorflow/core/kernels/training_ops.cc +++ b/tensorflow/core/kernels/training_ops.cc @@ -670,7 +670,9 @@ namespace functor { DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS DECLARE_GPU_SPEC(complex64); DECLARE_GPU_SPEC(complex128); @@ -682,7 +684,9 @@ DECLARE_GPU_SPEC(complex128); REGISTER_KERNELS(GPU, Eigen::half); REGISTER_KERNELS(GPU, float); REGISTER_KERNELS(GPU, double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS REGISTER_KERNELS(GPU, complex64); REGISTER_KERNELS(GPU, complex128); @@ -849,7 +853,9 @@ namespace functor { DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS DECLARE_GPU_SPEC(complex64); DECLARE_GPU_SPEC(complex128); @@ -861,7 +867,9 @@ DECLARE_GPU_SPEC(complex128); REGISTER_KERNELS(GPU, Eigen::half); REGISTER_KERNELS(GPU, float); REGISTER_KERNELS(GPU, double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS REGISTER_KERNELS(GPU, complex64); REGISTER_KERNELS(GPU, complex128); @@ -1340,7 +1348,9 @@ namespace functor { DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS DECLARE_GPU_SPEC(complex64); DECLARE_GPU_SPEC(complex128); @@ -1352,7 +1362,9 @@ DECLARE_GPU_SPEC(complex128); REGISTER_KERNELS(GPU, Eigen::half); REGISTER_KERNELS(GPU, float); REGISTER_KERNELS(GPU, double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS REGISTER_KERNELS(GPU, complex64); REGISTER_KERNELS(GPU, complex128); @@ -1456,7 +1468,9 @@ namespace functor { DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS DECLARE_GPU_SPEC(complex64); DECLARE_GPU_SPEC(complex128); @@ -1468,7 +1482,9 @@ DECLARE_GPU_SPEC(complex128); REGISTER_KERNELS(GPU, Eigen::half); REGISTER_KERNELS(GPU, float); REGISTER_KERNELS(GPU, double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS REGISTER_KERNELS(GPU, complex64); REGISTER_KERNELS(GPU, complex128); @@ -2957,7 +2973,9 @@ namespace functor { DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS DECLARE_GPU_SPEC(complex64); DECLARE_GPU_SPEC(complex128); @@ -2969,7 +2987,9 @@ DECLARE_GPU_SPEC(complex128); REGISTER_KERNELS(GPU, Eigen::half); REGISTER_KERNELS(GPU, float); REGISTER_KERNELS(GPU, double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS REGISTER_KERNELS(GPU, complex64); REGISTER_KERNELS(GPU, complex128); @@ -3195,7 +3215,9 @@ namespace functor { DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS DECLARE_GPU_SPEC(complex64); DECLARE_GPU_SPEC(complex128); @@ -3207,7 +3229,9 @@ DECLARE_GPU_SPEC(complex128); REGISTER_KERNELS(GPU, Eigen::half); REGISTER_KERNELS(GPU, float); REGISTER_KERNELS(GPU, double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS REGISTER_KERNELS(GPU, complex64); REGISTER_KERNELS(GPU, complex128); @@ -3337,7 +3361,9 @@ DECLARE_GPU_SPEC(float, int32); DECLARE_GPU_SPEC(float, int64); DECLARE_GPU_SPEC(double, int32); DECLARE_GPU_SPEC(double, int64); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS DECLARE_GPU_SPEC(complex64, int32); DECLARE_GPU_SPEC(complex64, int64); @@ -3355,7 +3381,9 @@ DECLARE_GPU_SPEC(complex128, int64); REGISTER_GPU_KERNELS(Eigen::half); REGISTER_GPU_KERNELS(float); REGISTER_GPU_KERNELS(double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS REGISTER_GPU_KERNELS(complex64); REGISTER_GPU_KERNELS(complex128); @@ -3622,7 +3650,9 @@ namespace functor { DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS DECLARE_GPU_SPEC(complex64); DECLARE_GPU_SPEC(complex128); @@ -3634,7 +3664,9 @@ DECLARE_GPU_SPEC(complex128); REGISTER_KERNELS(GPU, Eigen::half); REGISTER_KERNELS(GPU, float); REGISTER_KERNELS(GPU, double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS REGISTER_KERNELS(GPU, complex64); REGISTER_KERNELS(GPU, complex128); @@ -4151,7 +4183,9 @@ namespace functor { DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS DECLARE_GPU_SPEC(complex64); DECLARE_GPU_SPEC(complex128); @@ -4163,7 +4197,9 @@ DECLARE_GPU_SPEC(complex128); REGISTER_KERNELS(GPU, Eigen::half); REGISTER_KERNELS(GPU, float); REGISTER_KERNELS(GPU, double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS REGISTER_KERNELS(GPU, complex64); REGISTER_KERNELS(GPU, complex128); diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc index 0995b31e734..8b7f5dc2e40 100644 --- a/tensorflow/core/kernels/training_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc @@ -524,7 +524,9 @@ struct ApplyPowerSign { template struct functor::ApplyGradientDescent; template struct functor::ApplyGradientDescent; template struct functor::ApplyGradientDescent; -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS template struct functor::ApplyGradientDescent; template struct functor::ApplyGradientDescent; @@ -534,7 +536,9 @@ template struct functor::ApplyGradientDescent; template struct functor::ApplyAdagrad; template struct functor::ApplyAdagrad; template struct functor::ApplyAdagrad; -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS template struct functor::ApplyAdagrad; template struct functor::ApplyAdagrad; @@ -544,7 +548,9 @@ template struct functor::ApplyAdagrad; template struct functor::ApplyAdagradV2; template struct functor::ApplyAdagradV2; template struct functor::ApplyAdagradV2; -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS template struct functor::ApplyAdagradV2; template struct functor::ApplyAdagradV2; @@ -554,7 +560,9 @@ template struct functor::ApplyAdagradV2; template struct functor::ApplyAdadelta; template struct functor::ApplyAdadelta; template struct functor::ApplyAdadelta; -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS template struct functor::ApplyAdadelta; template struct functor::ApplyAdadelta; @@ -572,7 +580,9 @@ template struct functor::ApplyFtrlV2; template struct functor::ApplyMomentum; template struct functor::ApplyMomentum; template struct functor::ApplyMomentum; -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS template struct functor::ApplyMomentum; template struct functor::ApplyMomentum; @@ -582,7 +592,9 @@ template struct functor::ApplyMomentum; template struct functor::ApplyKerasMomentum; template struct functor::ApplyKerasMomentum; template struct functor::ApplyKerasMomentum; -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS template struct functor::ApplyKerasMomentum; template struct functor::ApplyKerasMomentum; @@ -597,7 +609,9 @@ template struct functor::SparseApplyKerasMomentum; template struct functor::SparseApplyKerasMomentum; template struct functor::SparseApplyKerasMomentum; template struct functor::SparseApplyKerasMomentum; -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS template struct functor::SparseApplyKerasMomentum; template struct functor::SparseApplyKerasMomentum; @@ -609,7 +623,9 @@ template struct functor::SparseApplyKerasMomentum; template struct functor::ApplyAdam; template struct functor::ApplyAdam; template struct functor::ApplyAdam; -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS template struct functor::ApplyAdam; template struct functor::ApplyAdam; @@ -627,7 +643,9 @@ template struct functor::ApplyAdaMax; template struct functor::ApplyRMSProp; template struct functor::ApplyRMSProp; template struct functor::ApplyRMSProp; -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS template struct functor::ApplyRMSProp; template struct functor::ApplyRMSProp; @@ -637,7 +655,9 @@ template struct functor::ApplyRMSProp; template struct functor::ApplyCenteredRMSProp; template struct functor::ApplyCenteredRMSProp; template struct functor::ApplyCenteredRMSProp; -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS template struct functor::ApplyCenteredRMSProp; template struct functor::ApplyCenteredRMSProp; diff --git a/tensorflow/python/keras/optimizer_v2/adadelta_test.py b/tensorflow/python/keras/optimizer_v2/adadelta_test.py index 4dad9198b85..5ff9a563f49 100644 --- a/tensorflow/python/keras/optimizer_v2/adadelta_test.py +++ b/tensorflow/python/keras/optimizer_v2/adadelta_test.py @@ -35,7 +35,8 @@ from tensorflow.python.platform import test _DATA_TYPES = [dtypes.half, dtypes.float32, dtypes.float64] # TODO(b/143684500): Eigen to support complex sqrt -if not test_util.IsBuiltWithNvcc() and platform.system() != "Windows": +if not test_util.IsBuiltWithNvcc() and platform.system() != "Windows" \ + and not test.is_built_with_rocm(): _DATA_TYPES += [dtypes.complex64, dtypes.complex128] diff --git a/tensorflow/python/keras/optimizer_v2/adagrad_test.py b/tensorflow/python/keras/optimizer_v2/adagrad_test.py index b0b661da8f7..c8e49a003d8 100644 --- a/tensorflow/python/keras/optimizer_v2/adagrad_test.py +++ b/tensorflow/python/keras/optimizer_v2/adagrad_test.py @@ -38,7 +38,8 @@ from tensorflow.python.platform import test _DATA_TYPES = [dtypes.half, dtypes.float32, dtypes.float64] # TODO(b/143684500): Eigen to support complex sqrt -if not test_util.IsBuiltWithNvcc() and platform.system() != "Windows": +if not test_util.IsBuiltWithNvcc() and platform.system() != "Windows" \ + and not test.is_built_with_rocm(): _DATA_TYPES += [dtypes.complex64, dtypes.complex128] diff --git a/tensorflow/python/keras/optimizer_v2/rmsprop_test.py b/tensorflow/python/keras/optimizer_v2/rmsprop_test.py index 0482b6f00b7..1a525004c37 100644 --- a/tensorflow/python/keras/optimizer_v2/rmsprop_test.py +++ b/tensorflow/python/keras/optimizer_v2/rmsprop_test.py @@ -41,7 +41,8 @@ from tensorflow.python.platform import test _DATA_TYPES = [dtypes.half, dtypes.float32, dtypes.float64] # TODO(b/143684500): Eigen to support complex sqrt -if not test_util.IsBuiltWithNvcc() and platform.system() != "Windows": +if not test_util.IsBuiltWithNvcc() and platform.system() != "Windows" \ + and not test.is_built_with_rocm(): _DATA_TYPES += [dtypes.complex64, dtypes.complex128] _TEST_PARAM_VALUES = [ From b1787be9984c04b2761600d215b77e9f3069749c Mon Sep 17 00:00:00 2001 From: Deven Desai Date: Fri, 8 Nov 2019 18:02:46 +0000 Subject: [PATCH 008/279] [ROCm] Fix for the broken ROCm CSB. The following commit breaks the --config=rocm build https://github.com/tensorflow/tensorflow/commit/f72695e1717a545bfc898b7230cc195bf28b43df The above commit adds a couple of subtests that require support for the `StatefulUnirformFullInt` Op on the GPU. Currently ROCm does not support that Op on the GPU, which leads to those subtests failing. The "fix" is to skip those subtests on the ROCm platform. --- .../python/keras/layers/image_preprocessing_test.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tensorflow/python/keras/layers/image_preprocessing_test.py b/tensorflow/python/keras/layers/image_preprocessing_test.py index d33acbf0de7..c25435a3d28 100644 --- a/tensorflow/python/keras/layers/image_preprocessing_test.py +++ b/tensorflow/python/keras/layers/image_preprocessing_test.py @@ -187,6 +187,11 @@ class RandomCropTest(keras_parameterized.TestCase): self._run_test(expected_height, expected_width) def test_training_with_mock(self): + if test.is_built_with_rocm(): + # TODO(rocm): + # re-enable this test once ROCm adds support for + # the StatefulUniformFullInt Op (on the GPU) + self.skipTest("Feature not supported on ROCm") np.random.seed(1337) height, width = 3, 4 height_offset = np.random.randint(low=0, high=3) @@ -207,6 +212,11 @@ class RandomCropTest(keras_parameterized.TestCase): ('random_crop_4_by_6', 4, 6), ('random_crop_3_by_2', 3, 2)) def test_random_crop_output_shape(self, expected_height, expected_width): + if test.is_built_with_rocm(): + # TODO(rocm): + # re-enable this test once ROCm adds support for + # the StatefulUniformFullInt Op (on the GPU) + self.skipTest("Feature not supported on ROCm") with CustomObjectScope({'RandomCrop': image_preprocessing.RandomCrop}): self._run_test(expected_height, expected_width) From 75054f706c6cc8249fb786e5de3bf95045f96e16 Mon Sep 17 00:00:00 2001 From: Douman Date: Fri, 8 Nov 2019 20:33:23 +0100 Subject: [PATCH 009/279] TfLite GL delegate is built with visibility hidden --- tensorflow/lite/delegates/gpu/BUILD | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD index 0fb5dc53488..26716cc2329 100644 --- a/tensorflow/lite/delegates/gpu/BUILD +++ b/tensorflow/lite/delegates/gpu/BUILD @@ -106,7 +106,7 @@ objc_library( ], ) -# build -c opt --config android_arm64 --copt -Os --copt -DTFLITE_GPU_BINARY_RELEASE --copt -fvisibility=hidden --linkopt -s --strip always :libtensorflowlite_gpu_gl.so +# build -c opt --config android_arm64 --copt -Os --copt -DTFLITE_GPU_BINARY_RELEASE --copt --linkopt -s --strip always :libtensorflowlite_gpu_gl.so cc_binary( name = "libtensorflowlite_gpu_gl.so", linkopts = [ @@ -116,7 +116,10 @@ cc_binary( "-lEGL", "-lGLESv3", ], - "//conditions:default": [], + "//tensorflow:windows": [], + "//conditions:default": [ + "-fvisibility=hidden", + ], }), linkshared = 1, linkstatic = 1, @@ -127,7 +130,7 @@ cc_binary( deps = [":gl_delegate"], ) -# build -c opt --config android_arm64 --copt -Os --copt -DTFLITE_GPU_BINARY_RELEASE --copt -fvisibility=hidden --linkopt -s --strip always :libtensorflowlite_gpu_delegate.so +# build -c opt --config android_arm64 --copt -Os --copt -DTFLITE_GPU_BINARY_RELEASE --copt --linkopt -s --strip always :libtensorflowlite_gpu_delegate.so cc_binary( name = "libtensorflowlite_gpu_delegate.so", linkopts = [ @@ -137,7 +140,10 @@ cc_binary( "-lEGL", "-lGLESv3", ], - "//conditions:default": [], + "//tensorflow:windows": [], + "//conditions:default": [ + "-fvisibility=hidden", + ], }), linkshared = 1, linkstatic = 1, From b37407fb18f89bb2297c27d95687aab748aa4f45 Mon Sep 17 00:00:00 2001 From: Douman Date: Mon, 11 Nov 2019 19:36:26 +0100 Subject: [PATCH 010/279] Add visibility default on android select According to docs, default kicks in only when nothing else matches --- tensorflow/lite/delegates/gpu/BUILD | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD index 26716cc2329..ecb4e287611 100644 --- a/tensorflow/lite/delegates/gpu/BUILD +++ b/tensorflow/lite/delegates/gpu/BUILD @@ -115,6 +115,7 @@ cc_binary( "//tensorflow:android": [ "-lEGL", "-lGLESv3", + "-fvisibility=hidden", ], "//tensorflow:windows": [], "//conditions:default": [ @@ -139,6 +140,7 @@ cc_binary( "//tensorflow:android": [ "-lEGL", "-lGLESv3", + "-fvisibility=hidden", ], "//tensorflow:windows": [], "//conditions:default": [ From 4382fbe39f57b6a84b2f5eaa4508255f607eb431 Mon Sep 17 00:00:00 2001 From: Deven Desai Date: Wed, 13 Nov 2019 20:16:09 +0000 Subject: [PATCH 011/279] [ROCm] Fix for the broken ROCm CSB. The following commit breaks the --config=rocm build https://github.com/tensorflow/tensorflow/commit/921003e1c4ff0421b34a3db7ddd338eb376d213f The above commit adds `//tensorflow/core/profiler/lib:traceme` as a build dependency on the CUDA side but not on the ROCm side, leading to bazel errors during the ROCm build. The "fix" is the to add the dependency on the ROCm side as well. --- tensorflow/core/nccl/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/nccl/BUILD b/tensorflow/core/nccl/BUILD index 2182e6a80b6..cb00208b9cd 100644 --- a/tensorflow/core/nccl/BUILD +++ b/tensorflow/core/nccl/BUILD @@ -31,7 +31,6 @@ cc_library( copts = tf_copts(), deps = if_cuda([ "@local_config_nccl//:nccl", - "//tensorflow/core/profiler/lib:traceme", ]) + if_rocm([ "@local_config_rocm//rocm:rccl", "//tensorflow/core:gpu_runtime", @@ -43,6 +42,7 @@ cc_library( "//tensorflow/core:gpu_headers_lib", "//tensorflow/core:lib", "//tensorflow/core:stream_executor", + "//tensorflow/core/profiler/lib:traceme", ]), alwayslink = 1, ) From b37a8f1f4c13719777a88b3f4c4679b5bc75bcdc Mon Sep 17 00:00:00 2001 From: Deven Desai Date: Thu, 14 Nov 2019 18:55:40 +0000 Subject: [PATCH 012/279] Revert "Don't include traceme.h for ROCM." This reverts commit 22a97f1ef5a8b10c72ba1ec8b463f063c6e5f22d. THe reverted commit attempts to fix the ROCm build, but fails to do so. It merely trades bazel dependency error for compile time errors like the following: ``` tensorflow/core/nccl/nccl_manager.cc: In member function 'void tensorflow::NcclManager::LoopKernelLaunches(tensorflow::NcclManager::NcclStream*)': tensorflow/core/nccl/nccl_manager.cc:689:9: error: 'profiler' has not been declared profiler::TraceMe trace_me("ncclAllReduce"); ^ tensorflow/core/nccl/nccl_manager.cc:718:9: error: 'profiler' has not been declared profiler::TraceMe trace_me("ncclBroadcast"); ^ tensorflow/core/nccl/nccl_manager.cc:729:9: error: 'profiler' has not been declared profiler::TraceMe trace_me("ncclReduce"); ^ ... ``` --- tensorflow/core/nccl/nccl_manager.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/nccl/nccl_manager.cc b/tensorflow/core/nccl/nccl_manager.cc index c3d6af92f93..2d799d93a6d 100644 --- a/tensorflow/core/nccl/nccl_manager.cc +++ b/tensorflow/core/nccl/nccl_manager.cc @@ -21,9 +21,9 @@ limitations under the License. #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/profiler/lib/traceme.h" #if GOOGLE_CUDA #include "tensorflow/core/platform/cuda.h" -#include "tensorflow/core/profiler/lib/traceme.h" #elif TENSORFLOW_USE_ROCM #include "tensorflow/core/platform/rocm.h" #endif From 340d3337d86c08911c6abce34ec0e449411d223e Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 9 Oct 2019 17:01:00 +0900 Subject: [PATCH 013/279] minor spelling tweaks --- .../mlir/lite/flatbuffer_translate.cc | 2 +- tensorflow/compiler/mlir/lite/ir/tfl_ops.cc | 8 +++--- tensorflow/compiler/mlir/lite/ir/tfl_ops.td | 4 +-- .../quantization/import_quant_stats_pass.cc | 2 +- .../mlir/lite/quantization/quantization.td | 2 +- .../lite/quantization/quantization_traits.h | 2 +- .../lite/quantization/quantization_utils.h | 4 +-- .../mlir/lite/tests/canonicalize.mlir | 8 +++--- .../mlir/lite/tests/extract-ophint.mlir | 4 +-- tensorflow/compiler/mlir/lite/tests/ops.mlir | 6 ++--- .../compiler/mlir/lite/tests/optimize.mlir | 2 +- .../compiler/mlir/lite/tf_tfl_translate_cl.cc | 2 +- .../mlir/lite/transforms/extract_ophint.cc | 26 +++++++++---------- .../transforms/legalize_ophint_func_op.cc | 12 ++++----- .../prepare_composite_functions_tf.cc | 4 +-- .../mlir/lite/transforms/prepare_quantize.cc | 2 +- .../mlir/lite/transforms/prepare_tf.cc | 2 +- .../lite/transforms/unroll_batch_matmul.cc | 2 +- .../compiler/mlir/lite/utils/lstm_utils.h | 2 +- tensorflow/compiler/mlir/runlit.site.cfg.py | 2 +- .../tensorflow/ir/dialect_registration.cc | 2 +- .../compiler/mlir/tensorflow/ir/tf_op_base.td | 2 +- .../compiler/mlir/tensorflow/ir/tf_ops.cc | 2 +- .../compiler/mlir/tensorflow/ir/tf_ops.h | 2 +- .../tensorflow/tests/cluster_formation.mlir | 6 ++--- .../tests/tf_executor_ops_invalid.mlir | 2 +- .../transforms/tf_graph_optimization_pass.cc | 2 +- .../mlir/tensorflow/translate/import_model.cc | 2 +- .../utils/compile_mlir_util_test.cc | 4 +-- tensorflow/compiler/mlir/xla/ir/hlo_ops.cc | 2 +- .../compiler/mlir/xla/ir/hlo_ops_base.td | 2 +- .../compiler/mlir/xla/mlir_hlo_to_hlo.cc | 2 +- 32 files changed, 64 insertions(+), 64 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc index 28d70375938..1e6aea7149e 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc @@ -381,7 +381,7 @@ class Translator { const ::tensorflow::NodeDef& node_def, const mlir::Location& loc); // Returns opcode index for op identified by the op_name, if already - // available. Otherwise, creates a new OperactorCode using the given `builtin` + // available. Otherwise, creates a new OperatorCode using the given `builtin` // operator and associates it with `op_name`. uint32_t GetOpcodeIndex(const std::string& op_name, tflite::BuiltinOperator builtin); diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 0549eadc88a..41d0c4129ec 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -857,8 +857,8 @@ void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, // // => Value [5, 8, 9] // TODO(b/133341698): Move to tablegen when variadic is supported. -struct RemoveRedunantUnpackPack : public RewritePattern { - explicit RemoveRedunantUnpackPack(MLIRContext *context) +struct RemoveRedundantUnpackPack : public RewritePattern { + explicit RemoveRedundantUnpackPack(MLIRContext *context) : RewritePattern(PackOp::getOperationName(), 2, context) {} PatternMatchResult matchAndRewrite(Operation *op, @@ -896,7 +896,7 @@ struct RemoveRedunantUnpackPack : public RewritePattern { void PackOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -1041,7 +1041,7 @@ struct DropFakeQuant : public RewritePattern { } void rewrite(Operation *op, PatternRewriter &rewriter) const override { - // Replace the matched FakeQuantOp by its primiary operand. + // Replace the matched FakeQuantOp by its primary operand. rewriter.replaceOp(op, op->getOperand(0)); } }; diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 2eec45e3203..767a04eb7fa 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -35,7 +35,7 @@ def TFL_Dialect : Dialect { Invariants: * All values are of Tensor type (in particular, scalars are - represented using zero-dimentional tensors); + represented using zero-dimensional tensors); }]; let cppNamespace = "TFL"; @@ -581,7 +581,7 @@ def TFL_FCWO_Default : StrEnumAttrCase<"DEFAULT">; def TFL_FCWO_Shuffled4x16i8 : StrEnumAttrCase<"SHUFFLED4x16INT8">; def TFL_FullyConnectedOptionsWeightFormatAttr : - StrEnumAttr<"FullyConectedOptionsWeightsFormat", + StrEnumAttr<"FullyConnectedOptionsWeightsFormat", "fully connected options weights format", [ TFL_FCWO_Default, TFL_FCWO_Shuffled4x16i8 ]>; diff --git a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc index 4beb4ef9ecf..0326d122c07 100644 --- a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc +++ b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc @@ -82,7 +82,7 @@ class ImportQuantStatsPass : public FunctionPass { res->getType().cast().getElementType().isa(); } - // A method to retrive the name for the given op. + // A method to retrieve the name for the given op. OperationToName op_to_name_; // We split the normal names and regex names, since the former can use hash diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization.td b/tensorflow/compiler/mlir/lite/quantization/quantization.td index 996cbbad56b..9c11e7e5ab2 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization.td +++ b/tensorflow/compiler/mlir/lite/quantization/quantization.td @@ -107,7 +107,7 @@ class AffineOpCoefficient : NativeOpTrait< StrJoinInt<[dim, index]>.result, ">::Impl")>; -// Specify this trait if the op doesn't have quantizable ouput. We shouldn't +// Specify this trait if the op doesn't have quantizable output. We shouldn't // apply quantization on this op. def NoQuantizableResult : NativeOpTrait<"quant::NoQuantizableResult">; diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h b/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h index 9f027d27bc2..3830d11afe4 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h @@ -54,7 +54,7 @@ class SameOperandsAndResultsScale // OpTrait::quant::FixedResultUniformScale< // 8, -128, 390625, -8, 0, 255, false>::Impl> { // -// TODO(fengliuai): create a better way to epxress floating point scale in the +// TODO(fengliuai): create a better way to express floating point scale in the // template argument list. template diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h index 42deac24d56..ab7ad4ff91b 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h @@ -133,10 +133,10 @@ struct ConvertStatsToQDQs : public OpRewritePattern { // quantization parameters are annotated by the Q/DQ op pairs. Each // matched pattern are rewritten by its quantized alternatives. // -// The concret pattern, extends from this base pattern, can specify whether it +// The concrete pattern, extends from this base pattern, can specify whether it // allows "hybrid" operands or results. These "hybrid" operands and results // don't have quantization parameters propagated to, so will be in float in the -// quantized results. The concret pattern should define the following two +// quantized results. The concrete pattern should define the following two // functions: // // bool AllowHybridOperand() const diff --git a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir index 3130c5c2042..ef77288ad27 100644 --- a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir @@ -114,8 +114,8 @@ func @fakequant_notdropfakequant(tensor, f32, f32) -> tensor { // ----- -// CHECK-LABEL: @RemoveRedunantUnpackPack -func @RemoveRedunantUnpackPack(%arg0: tensor<2x5xf32>) -> tensor<2x5xf32> { +// CHECK-LABEL: @RemoveRedundantUnpackPack +func @RemoveRedundantUnpackPack(%arg0: tensor<2x5xf32>) -> tensor<2x5xf32> { %0:2 = "tfl.unpack"(%arg0) {axis = 0 : i32, num = 2 : i32} : (tensor<2x5xf32>) -> (tensor<5xf32>, tensor<5xf32>) %1 = "tfl.pack"(%0#0, %0#1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<5xf32>, tensor<5xf32>) -> (tensor<2x5xf32>) return %1: tensor<2x5xf32> @@ -125,8 +125,8 @@ func @RemoveRedunantUnpackPack(%arg0: tensor<2x5xf32>) -> tensor<2x5xf32> { // ----- -// CHECK-LABEL: @RemoveRedunantPack -func @RemoveRedunantPack(%arg0: tensor<2x5xf32>) -> (tensor<2x5xf32>, tensor<5xf32>) { +// CHECK-LABEL: @RemoveRedundantPack +func @RemoveRedundantPack(%arg0: tensor<2x5xf32>) -> (tensor<2x5xf32>, tensor<5xf32>) { %0:2 = "tfl.unpack"(%arg0) {axis = 0 : i32, num = 2 : i32} : (tensor<2x5xf32>) -> (tensor<5xf32>, tensor<5xf32>) %1 = "tfl.pack"(%0#0, %0#1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<5xf32>, tensor<5xf32>) -> (tensor<2x5xf32>) return %1, %0#0: tensor<2x5xf32>, tensor<5xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/extract-ophint.mlir b/tensorflow/compiler/mlir/lite/tests/extract-ophint.mlir index 215c2d6d94e..bde800897c5 100644 --- a/tensorflow/compiler/mlir/lite/tests/extract-ophint.mlir +++ b/tensorflow/compiler/mlir/lite/tests/extract-ophint.mlir @@ -106,8 +106,8 @@ func @extractStackInputOutputOphint() { // CHECK: %[[PACK:[0-9]*]] = "tfl.pack"(%0, %1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<2x1x16x1xf32> // CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @b92ed354b9f011e99426dc4a3e957995(%[[PACK]]) : (tensor<2x1x16x1xf32>) -> tensor<2x1x16x1xf32> // CHECK: %[[UNPACK:[0-9]*]]:2 = "tfl.unpack"(%[[OP_HINT_CALL]]) {axis = 0 : i32, num = 2 : i32} : (tensor<2x1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -// CHECK-DAG: %[[OUPUT:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#1) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 1 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-1-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> -// CHECK-DAG: %[[OUPUT_1:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> +// CHECK-DAG: %[[OUTPUT:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#1) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 1 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-1-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> +// CHECK-DAG: %[[OUTPUT_1:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> %0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32> %1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index 84a2af960a7..b6353faa147 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -440,7 +440,7 @@ func @testEluI32(%arg0: tensor) -> tensor { // ----- -func @testFusedActiviationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) { +func @testFusedActivationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) { // CHECK: "NONE" %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<4xi32> // CHECK: "RELU" @@ -458,7 +458,7 @@ func @testFusedActiviationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) - // ----- -func @testFusedActiviationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { +func @testFusedActivationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { // expected-error @+1 {{attribute 'fused_activation_function' failed to satisfy constraint: fused activation enum}} %0 = tfl.add %arg0, %arg1 {fused_activation_function = "Relu6"} : tensor<4xi32> return %0: tensor<4xi32> @@ -1079,7 +1079,7 @@ func @testConcatInvalidOperandDimSize(%arg0: tensor<1x2xi32>, %arg1: tensor<1x3x // ----- -func @testConcatInvalidOperandDimSizeComaredToPrevInput(%arg0: tensor<1x2xi32>, %arg1: tensor<1x3xi32>) -> tensor { +func @testConcatInvalidOperandDimSizeComparedToPrevInput(%arg0: tensor<1x2xi32>, %arg1: tensor<1x3xi32>) -> tensor { // expected-error @+1 {{'tfl.concatenation' op dimension size of dimension #1 of operand #1 must be equal to dimension size of dimension #1 of operand #0, expected 2, got 3}} %0 = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xi32>, tensor<1x3xi32>) -> tensor return %0 : tensor diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index b7df64a6e06..3297e30c288 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -278,7 +278,7 @@ func @notFuseMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x %cst2 = constant dense<3.0> : tensor<112x2xf32> %0 = "tfl.depthwise_conv_2d"(%arg0, %cst0, %cst1) {depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x112x2xf32>, tensor<1x3x3x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32> - // We cannot fuse this tfl.mul into the preceding conv op becuase %cst2 is not broadcast-compatible to %cst0. + // We cannot fuse this tfl.mul into the preceding conv op because %cst2 is not broadcast-compatible to %cst0. %1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x112x112x2xf32>, tensor<112x2xf32>) -> tensor<1x112x112x2xf32> return %1 : tensor<1x112x112x2xf32> diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.cc index b7de1acd41b..b0b441a1b60 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.cc @@ -30,7 +30,7 @@ opt output_file_name("o", llvm::cl::desc(""), opt use_splatted_constant( "use-splatted-constant", llvm::cl::desc( - "Replace constants with randonmly generated splatted tensors"), + "Replace constants with randomly generated splatted tensors"), llvm::cl::init(false), llvm::cl::Hidden); // NOLINTNEXTLINE opt input_mlir( diff --git a/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc b/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc index 6c51b5fb1c6..5140a77d8c1 100644 --- a/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc +++ b/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc @@ -282,7 +282,7 @@ struct OphintCompositeOp { // Since we have different aggregation strategies, e.g., "first", "last", // "stack". We don't somehow aggregated to get the outputs for the funcOp. // This function is simply compute the RankedTensorType (shape & element type) - std::map GetAggregatedOuputTypes(OpBuilder* builder) { + std::map GetAggregatedOutputTypes(OpBuilder* builder) { std::map aggregated_output_types; for (const auto& kv : outputs) { const AggregatedOperand& operand = kv.second; @@ -387,11 +387,11 @@ struct OphintCompositeOp { // inputs/outputs indicate edges) Assume the graph is acyclic. The preprocess // does the following: // Compute each operations's in-degress (how many input nodes they're taken) -// Get all consumer operations for every operations. (operation_to_ouputs) +// Get all consumer operations for every operations. (operation_to_outputs) // Get the init_queue (those operations will be processed first). void PreprocessTopoSortGraph( Block* block, std::queue* init_queue, - llvm::DenseMap>* operation_to_ouputs, + llvm::DenseMap>* operation_to_outputs, llvm::DenseMap* operation_to_in_degrees) { for (auto& op : *block) { if (&op == block->getTerminator()) continue; @@ -412,9 +412,9 @@ void PreprocessTopoSortGraph( } operation_to_in_degrees->try_emplace(&op, input_ops.size()); for (auto* input_op : input_ops) { - auto preceeding_op_it = operation_to_ouputs->find(input_op); - if (preceeding_op_it == operation_to_ouputs->end()) { - auto result = operation_to_ouputs->try_emplace( + auto preceeding_op_it = operation_to_outputs->find(input_op); + if (preceeding_op_it == operation_to_outputs->end()) { + auto result = operation_to_outputs->try_emplace( input_op, llvm::DenseSet()); preceeding_op_it = result.first; } @@ -442,19 +442,19 @@ bool IsSideEffectOp(Operation* op) { // Also assume the block has no arguments. LogicalResult TopoSortOperations(OpBuilder* builder) { std::queue init_queue; - llvm::DenseMap> operation_to_ouputs; + llvm::DenseMap> operation_to_outputs; llvm::DenseMap operation_to_in_degrees; std::vector sorted_ops; PreprocessTopoSortGraph(builder->getBlock(), &init_queue, - &operation_to_ouputs, &operation_to_in_degrees); + &operation_to_outputs, &operation_to_in_degrees); while (!init_queue.empty()) { Operation* current_op = init_queue.front(); init_queue.pop(); sorted_ops.push_back(current_op); - auto current_op_to_output_it = operation_to_ouputs.find(current_op); - if (current_op_to_output_it == operation_to_ouputs.end()) { + auto current_op_to_output_it = operation_to_outputs.find(current_op); + if (current_op_to_output_it == operation_to_outputs.end()) { continue; } for (Operation* output_op : current_op_to_output_it->second) { @@ -467,7 +467,7 @@ LogicalResult TopoSortOperations(OpBuilder* builder) { operation_to_in_degrees.erase(output_op_it); } } - operation_to_ouputs.erase(current_op_to_output_it); + operation_to_outputs.erase(current_op_to_output_it); } // Before we performs the sort. We need to make sure we didn't mess the @@ -629,11 +629,11 @@ LogicalResult ConvertOphintToStub(StringRef stub_name, // Step 4, get aggregated output types. const std::map& aggregated_output_types = - ophint_composite_op.GetAggregatedOuputTypes(builder); + ophint_composite_op.GetAggregatedOutputTypes(builder); // Step 5, create & place the fused op and rewire the inputs. // Here we use a funcOp to represent the fused op. This "funcOp" will be - // coonverted to other ops (like UnidirectionalSequenceRNNOp) in the + // converted to other ops (like UnidirectionalSequenceRNNOp) in the // legalization phase. Operation* inserted_before_op = ophint_composite_op.GetFirstOutputOp(); Operation* fused_op = BuildFusedFuncOp( diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc index 8d3a78f49fe..ed3a9ea5000 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc @@ -191,10 +191,10 @@ LogicalResult BuildUnidirectionalSequenceLSTMOp(FuncOp composite_func_op, return success(); } -LogicalResult ConvertTfLiteFusedOpIfAvaiable(StringRef func_name, - FuncOp composite_func_op, - CallOp call_op, - OpBuilder* builder) { +LogicalResult ConvertTfLiteFusedOpIfAvailable(StringRef func_name, + FuncOp composite_func_op, + CallOp call_op, + OpBuilder* builder) { Operation* fused_op = nullptr; if (func_name == kUnidirectionalSequenceRnn) { // TODO(renjieliu): Validate the func op inputs. @@ -243,8 +243,8 @@ LogicalResult ConvertCallOps(llvm::StringMap* composite_func_ops, StringRef func_name = composite_func_op.getAttr(kTfLiteFunctionName) .cast() .getValue(); - if (failed(ConvertTfLiteFusedOpIfAvaiable(func_name, composite_func_op, - call_op, &builder))) + if (failed(ConvertTfLiteFusedOpIfAvailable(func_name, composite_func_op, + call_op, &builder))) return failure(); composite_func_ops->erase(it); diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc index 20ef595f0d2..9aa5c56b9a1 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc @@ -67,11 +67,11 @@ class ConvertEmbeddedLookupFunc { if (func_.getNumArguments() != 2) { return func_.emitError() << "Invalid number of arguments in the embedding " - "matmal composite function"; + "matmul composite function"; } if (func_.getType().getNumResults() != 1) { return func_.emitError() << "Invalid number of results in the embedding " - "matmal composite function"; + "matmul composite function"; } return success(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc index e5069093c3b..0b702e13a75 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc @@ -34,7 +34,7 @@ limitations under the License. // NOLINTNEXTLINE static llvm::cl::list quantize_whitelist( "tfl-test-quantize-whitelist", llvm::cl::value_desc("list"), - llvm::cl::desc("comma seprarated list of whitelisted functions to be " + llvm::cl::desc("comma separated list of whitelisted functions to be " "quantized. Only used in tests"), llvm::cl::CommaSeparated); diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index 3da45e930d3..823efdc3ef5 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -400,7 +400,7 @@ class ConvertTFDepthwiseConv2dNative } }; -// StridedSlice can have complicated atributes like begin_axis_mask, +// StridedSlice can have complicated attributes like begin_axis_mask, // end_axis_mask, ellipsis_axis_mask, new_axis_mask, shrink_axis_mask. These // masks will complicate the strided_slice computation logic, we can simplify // the logic by inserting a reshape op to pad the inputs so strided_slice can diff --git a/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc b/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc index 8b354cc9875..61d33a5233e 100644 --- a/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc +++ b/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc @@ -247,7 +247,7 @@ PatternMatchResult ConvertTFBatchMatMulOp::matchAndRewrite( } if (lhs_shape[dims_a - 1] != rhs_shape[dims_b - 2]) { - // Input dimensions must be compatible for multipication. + // Input dimensions must be compatible for multiplication. return this->matchFailure(); } diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.h b/tensorflow/compiler/mlir/lite/utils/lstm_utils.h index 200359eacf6..71905e1770b 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.h @@ -125,7 +125,7 @@ class ConvertLSTMCellSimpleToFusedLSTM { Value* input2cell_; Value* input2output_; - // reccurrent -> cifg + // recurrent -> cifg Value* rec2input_; Value* rec2forget_; Value* rec2cell_; diff --git a/tensorflow/compiler/mlir/runlit.site.cfg.py b/tensorflow/compiler/mlir/runlit.site.cfg.py index 1ae06e36c25..e14199ed43b 100644 --- a/tensorflow/compiler/mlir/runlit.site.cfg.py +++ b/tensorflow/compiler/mlir/runlit.site.cfg.py @@ -28,7 +28,7 @@ config.llvm_tools_dir = os.path.join(os.environ['TEST_SRCDIR'], 'llvm') config.mlir_obj_root = os.path.join(os.environ['TEST_SRCDIR']) config.mlir_tools_dir = os.path.join(os.environ['TEST_SRCDIR'], 'local_config_mlir') -# TODO(jpienaar): Replace with sufffices in build rule. +# TODO(jpienaar): Replace with suffices in build rule. config.suffixes = ['.td', '.mlir', '.pbtxt'] mlir_tf_tools_dirs = [ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc b/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc index 4bac3be1d1e..ac468d9810c 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc @@ -26,7 +26,7 @@ static DialectRegistration tf_control_flow_ops; static DialectRegistration tf_ops; static DialectRegistration - tf_excutor_dialect; + tf_executor_dialect; static DialectRegistration tf_device_dialect; static DialectRegistration diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td index e324b70325e..846ea59cebe 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td @@ -41,7 +41,7 @@ This dialect maps to TensorFlow operations. Invariants: * All values are of Tensor type (in particular, scalars are - represented using zero-dimentional tensors); + represented using zero-dimensional tensors); TODO: Make invariants more structured so that we can reference them in ops. }]; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 96e310a2a27..1073bfe0f31 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -479,7 +479,7 @@ void ConstOp::build(Builder *builder, OperationState &result, Attribute value) { } else if (value.isa() || value.isa() || value.isa()) { // All TensorFlow types must be tensor types. In the build() method, - // we want to provide more flexiblity by allowing attributes of scalar + // we want to provide more flexibility by allowing attributes of scalar // types. But we need to wrap it up with ElementsAttr to construct // valid TensorFlow constants. type = RankedTensorType::get(/*shape=*/{}, value.getType()); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h index 7aa7f670e31..74d0f3e1dc4 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h @@ -41,7 +41,7 @@ class TensorFlowDialect : public Dialect { static StringRef getDialectNamespace() { return "tf"; } - // Gradient attribute ("tf.gradient") in the list of NamedAttibutes in a + // Gradient attribute ("tf.gradient") in the list of NamedAttributes in a // function references to its gradient function. This attribute in TensorFlow // Dialect is used to model TF GradientDef. GetGradientAttrName() returns the // string description of gradient attribute. diff --git a/tensorflow/compiler/mlir/tensorflow/tests/cluster_formation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/cluster_formation.mlir index 0356070bb0a..b9deb2799c8 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/cluster_formation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/cluster_formation.mlir @@ -232,12 +232,12 @@ module { // ----- -// Single device with non-continous instructions in original block. +// Single device with non-continuous instructions in original block. module { - // CHECK-LABEL: func @noncontinoussinglecluster + // CHECK-LABEL: func @noncontinuoussinglecluster // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) - func @noncontinoussinglecluster(%arg0: tensor) -> tensor { + func @noncontinuoussinglecluster(%arg0: tensor) -> tensor { %0 = tf_executor.graph { %1:2 = tf_executor.island { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir index 1570c0cab33..62f0faec9a9 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir @@ -535,7 +535,7 @@ func @invalid_merge(%arg0: tensor<*x!tf.resource>, %arg1: tensor<4x!tf.resource> // ----- // Check that if result is a ref type, all operands need to be ref too. -func @inavlid_merge(%arg0: tensor<4x!tf.f32ref>, %arg1: tensor<4xf32>) -> tensor<4x!tf.f32ref> { +func @invalid_merge(%arg0: tensor<4x!tf.f32ref>, %arg1: tensor<4xf32>) -> tensor<4x!tf.f32ref> { %result = tf_executor.graph { %value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<4x!tf.f32ref>, tensor<4xf32>) -> (tensor<4x!tf.f32ref>, tensor, !tf_executor.control) // expected-error@-1 {{'tf_executor.Merge' op expects same operand and output element type but got 'tensor<4xf32>' vs 'tensor<4x!tf.f32ref>'}} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc index 4b74f3e6ca3..2eb12c80efe 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc @@ -141,7 +141,7 @@ static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); // NOLINTNEXTLINE static llvm::cl::list cl_pass_list( "graph-passes", llvm::cl::value_desc("list"), - llvm::cl::desc("comma seprarated list of GraphOptimizationPass to run."), + llvm::cl::desc("comma separated list of GraphOptimizationPass to run."), llvm::cl::CommaSeparated, llvm::cl::cat(clOptionsCategory)); class GraphOptByNamePass : public GraphOptPass { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index a3880827a97..2660e2d855d 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -534,7 +534,7 @@ Status ImporterBase::AddNodesToShapeRefiner() { auto node_name = node->op_def().name(); if (node_name != "Placeholder" && node_name != "LegacyFedInput" && node_name != FunctionLibraryDefinition::kArgOp) { - // We do not handle the case where the input node has multple outputs + // We do not handle the case where the input node has multiple outputs if (node->num_outputs() > 1) { return errors::FailedPrecondition(absl::StrCat( "Input arrays can only have op with single output. Node op:", diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc index 7a0c8d62d6a..4ed41bf2054 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc @@ -36,7 +36,7 @@ xla::StatusOr TestShapeRepresentation(const TensorShape& shape, return xla_shape; } -TEST(CompileSerializedMlirToXlaHloTest, InvalidSerliazedMlirModule) { +TEST(CompileSerializedMlirToXlaHloTest, InvalidSerializedMlirModule) { string invalid_mlir_module = "totally @invalid MLIR module {here} <-"; std::vector arg_shapes; XlaCompiler::CompilationResult compilation_result; @@ -100,7 +100,7 @@ ENTRY %main.5 (arg_tuple.1: (f32[], f32[])) -> f32[] { xla::ShapeUtil::MakeTupleShape({output_shape}); EXPECT_EQ(compilation_result.xla_output_shape, tuple_output_shape); - // Expect exactly 1 OutputDescrpition. + // Expect exactly 1 OutputDescription. EXPECT_EQ(compilation_result.outputs.size(), 1); const XlaCompiler::OutputDescription& output_desc = compilation_result.outputs.front(); diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc index a046e3acec2..e1c5d3dea58 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc @@ -171,7 +171,7 @@ void ConstOp::build(Builder* builder, OperationState& result, Attribute value) { } else if (value.isa() || value.isa() || value.isa()) { // All XLA types must be tensor types. In the build() method, we want to - // provide more flexiblity by allowing attributes of scalar types. But we + // provide more flexibility by allowing attributes of scalar types. But we // need to wrap it up with ElementsAttr to construct valid XLA constants. type = RankedTensorType::get(/*shape=*/{}, value.getType()); value = DenseElementsAttr::get(type.cast(), value); diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td index a97fbd672b5..d7c2108625c 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td @@ -655,7 +655,7 @@ class BASE_HLO_ClampOp { } class BASE_HLO_ConcatenateOp { - string summary = "XLA's concantenate op"; + string summary = "XLA's concatenate op"; string description = [{ Concatenates a set of tensors along the specified dimension. diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index 41b1940e1e1..52b4633cce7 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -680,7 +680,7 @@ LogicalResult ConvertToHloModule::LowerFunctionCall( LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) { if (lowered_computation_.count(f)) return success(); if (f.getBlocks().size() != 1) { - return f.emitError("only single block Function suppored"); + return f.emitError("only single block Function supported"); } // Create a sub-builder if this is not the main function. From d1f49f699f0691f349b10e4e75ade0a36b712af0 Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Fri, 8 Nov 2019 19:53:46 -0800 Subject: [PATCH 014/279] Move TensorRT builder configs to converter build function --- .../tf2tensorrt/convert/convert_nodes.cc | 59 ++++++++++--------- .../tf2tensorrt/convert/convert_nodes.h | 5 +- 2 files changed, 35 insertions(+), 29 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index c800e5005d2..0dca101ed79 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -1369,13 +1369,33 @@ Status Converter::RenameAndMarkOutputTensors( } Status Converter::BuildCudaEngine( - TrtUniquePtrType* engine) { - VLOG(1) << "Starting engine creation"; + TrtUniquePtrType* engine, + int max_batch_size, size_t max_workspace_size_bytes, + nvinfer1::IGpuAllocator* allocator, TRTInt8Calibrator* calibrator) { + VLOG(1) << "Configuring TensorRT builder"; + trt_builder_->setMaxBatchSize(max_batch_size); + trt_builder_->setMaxWorkspaceSize(max_workspace_size_bytes); + trt_builder_->setGpuAllocator(allocator); + if (precision_mode_ == TrtPrecisionMode::FP16) { + trt_builder_->setFp16Mode(true); + } else if (precision_mode_ == TrtPrecisionMode::INT8) { + // Setting FP16 mode as well allows TRT to also consider FP16 kernels and + // use them in situations where they are faster than INT8 or where INT8 is + // not supported for a given layer. + trt_builder_->setFp16Mode(true); + trt_builder_->setInt8Mode(true); + if (use_calibration_) { + trt_builder_->setInt8Calibrator(calibrator); + } else { + trt_builder_->setInt8Calibrator(nullptr); + } + } + + VLOG(1) << "Building TensorRT engine"; engine->reset(trt_builder_->buildCudaEngine(*network())); if (engine->get() == nullptr) { return errors::Internal("Failed to build TensorRT engine"); } - VLOG(1) << "Finished conversion"; return Status::OK(); } @@ -5620,37 +5640,17 @@ Status ConvertGraphDefToEngine( engine->reset(); if (convert_successfully) *convert_successfully = false; - // Create the builder. + VLOG(1) << "Creating TensorRT builder"; TrtUniquePtrType builder( nvinfer1::createInferBuilder(*trt_logger)); - builder->setMaxBatchSize(max_batch_size); - builder->setMaxWorkspaceSize(max_workspace_size_bytes); - builder->setGpuAllocator(allocator); - if (precision_mode == TrtPrecisionMode::FP16) { - builder->setFp16Mode(true); - } else if (precision_mode == TrtPrecisionMode::INT8) { - // Setting FP16 mode as well allows TRT to also consider FP16 kernels and - // use them in situations where they are faster than INT8 or where INT8 is - // not supported for a given layer. - builder->setFp16Mode(true); - builder->setInt8Mode(true); - if (use_calibration) { - builder->setInt8Calibrator(calibrator); - } else { - builder->setInt8Calibrator(nullptr); - } - } - // Build the network - if (VLOG_IS_ON(1)) { - string mode_str; - TF_RETURN_IF_ERROR(TrtPrecisionModeToName(precision_mode, &mode_str)); - VLOG(1) << "Starting engine conversion, precision mode: " << mode_str; - } + VLOG(1) << "Creating converter and TensorRT network"; auto statusor = Converter::Create(builder.get(), precision_mode, use_calibration, trt_logger); TF_RETURN_IF_ERROR(statusor.status()); auto converter = std::move(statusor.ValueOrDie()); + + VLOG(1) << "Starting to convert TensorFlow ops to TensorRT layers"; std::vector output_tensors; // Graph nodes are already topologically sorted during construction for (const auto& node_def : gdef.node()) { @@ -5737,7 +5737,10 @@ Status ConvertGraphDefToEngine( converter->MaybeApplyQuantizationRanges(); // Build the engine. - TF_RETURN_IF_ERROR(converter->BuildCudaEngine(engine)); + TF_RETURN_IF_ERROR(converter->BuildCudaEngine( + engine, max_batch_size, max_workspace_size_bytes, allocator, calibrator)); + + VLOG(1) << "Finished conversion"; return Status::OK(); } diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index 00099396308..b3dc37322ea 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -467,7 +467,10 @@ class Converter { const std::vector& output_tensors); // Build a TRT engine using the created network. - Status BuildCudaEngine(TrtUniquePtrType* engine); + Status BuildCudaEngine(TrtUniquePtrType* engine, + int max_batch_size, size_t max_workspace_size_bytes, + nvinfer1::IGpuAllocator* allocator, + TRTInt8Calibrator* calibrator); ////////////////////////////////////////////////////////////////////////////// // Methods used by op converters to convert individual TF node and add layers From 9d5cbf4e94e08f0e598a36fbc162417d7b031b66 Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Thu, 14 Nov 2019 16:25:42 -0800 Subject: [PATCH 015/279] Move builder to Converter class --- .../tf2tensorrt/convert/convert_nodes.cc | 25 ++++++++----------- .../tf2tensorrt/convert/convert_nodes.h | 2 +- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 0dca101ed79..ea608568db9 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -1217,18 +1217,16 @@ static void InitializeTrtPlugins(nvinfer1::ILogger* trt_logger) { // static StatusOr> Converter::Create( - nvinfer1::IBuilder* trt_builder, TrtPrecisionMode precision_mode, - bool use_calibration, nvinfer1::ILogger* trt_logger) { + TrtPrecisionMode precision_mode, bool use_calibration, + nvinfer1::ILogger* trt_logger) { std::unique_ptr converter = absl::WrapUnique( - new Converter(trt_builder, precision_mode, use_calibration, trt_logger)); + new Converter(precision_mode, use_calibration, trt_logger)); TF_RETURN_IF_ERROR(converter->Init()); return converter; } -Converter::Converter(nvinfer1::IBuilder* trt_builder, - TrtPrecisionMode precision_mode, bool use_calibration, +Converter::Converter(TrtPrecisionMode precision_mode, bool use_calibration, nvinfer1::ILogger* trt_logger) - : trt_builder_(trt_builder), precision_mode_(precision_mode), use_calibration_(use_calibration) { InitializeTrtPlugins(trt_logger); @@ -1236,7 +1234,10 @@ Converter::Converter(nvinfer1::IBuilder* trt_builder, } Status Converter::Init() { - // Create the network. + VLOG(1) << "Creating TensorRT builder"; + trt_builder_.reset(nvinfer1::createInferBuilder(*trt_logger)); + + VLOG(1) << "Creating TensorRT network"; trt_network_.reset(trt_builder_->createNetwork()); if (!trt_network_) { return errors::Internal("Failed to create TensorRT network object"); @@ -5640,13 +5641,9 @@ Status ConvertGraphDefToEngine( engine->reset(); if (convert_successfully) *convert_successfully = false; - VLOG(1) << "Creating TensorRT builder"; - TrtUniquePtrType builder( - nvinfer1::createInferBuilder(*trt_logger)); - - VLOG(1) << "Creating converter and TensorRT network"; - auto statusor = Converter::Create(builder.get(), precision_mode, - use_calibration, trt_logger); + // Creating converter, TensorRT builder and network + auto statusor = Converter::Create( + precision_mode, use_calibration, trt_logger); TF_RETURN_IF_ERROR(statusor.status()); auto converter = std::move(statusor.ValueOrDie()); diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index b3dc37322ea..d4c24ef41d9 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -563,7 +563,7 @@ class Converter { std::unordered_map trt_tensors_; // The TRT builder used to create the network and build the engine. Not owned. - nvinfer1::IBuilder* trt_builder_; + TrtUniquePtrType trt_builder_; // The TRT network being built. TrtUniquePtrType trt_network_; From d6dbc7237fb0a0665911d7efa82fa60231a7bcb2 Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Tue, 19 Nov 2019 11:03:55 -0800 Subject: [PATCH 016/279] Fix function signatures --- tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc | 6 +++--- tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index ea608568db9..5e7468f48a7 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -1221,19 +1221,19 @@ StatusOr> Converter::Create( nvinfer1::ILogger* trt_logger) { std::unique_ptr converter = absl::WrapUnique( new Converter(precision_mode, use_calibration, trt_logger)); - TF_RETURN_IF_ERROR(converter->Init()); + TF_RETURN_IF_ERROR(converter->Init(trt_logger)); return converter; } Converter::Converter(TrtPrecisionMode precision_mode, bool use_calibration, nvinfer1::ILogger* trt_logger) - precision_mode_(precision_mode), + : precision_mode_(precision_mode), use_calibration_(use_calibration) { InitializeTrtPlugins(trt_logger); this->RegisterOpConverters(); } -Status Converter::Init() { +Status Converter::Init(nvinfer1::ILogger* trt_logger) { VLOG(1) << "Creating TensorRT builder"; trt_builder_.reset(nvinfer1::createInferBuilder(*trt_logger)); diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index d4c24ef41d9..7b2f25d7dc1 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -446,7 +446,7 @@ class Converter { }; static StatusOr> Create( - nvinfer1::IBuilder* trt_builder, TrtPrecisionMode precision_mode, + TrtPrecisionMode precision_mode, bool use_calibration, nvinfer1::ILogger* trt_logger); ////////////////////////////////////////////////////////////////////////////// @@ -529,10 +529,10 @@ class Converter { const nvinfer1::Dims& dims); private: - Converter(nvinfer1::IBuilder* trt_builder, TrtPrecisionMode precision_mode, + Converter(TrtPrecisionMode precision_mode, bool use_calibration, nvinfer1::ILogger* trt_logger); - Status Init(); + Status Init(nvinfer1::ILogger* trt_logger); // Verify the provided batch_size is consistent with batch_size_ and update it // if necessary. From 51cec9092e73f22cabaca3c7e0ba7154b1a4b27d Mon Sep 17 00:00:00 2001 From: Mohamed Nour Abouelseoud Date: Tue, 27 Aug 2019 16:29:28 +0100 Subject: [PATCH 017/279] [Lite] Support Int8 Unpack Operator Added support for Unpack Operator Added relevant tests. --- tensorflow/lite/tools/optimize/BUILD | 1 + .../lite/tools/optimize/operator_property.cc | 8 ++- .../tools/optimize/quantize_model_test.cc | 50 ++++++++++++++++++ tensorflow/lite/tools/optimize/test_util.cc | 2 + tensorflow/lite/tools/optimize/test_util.h | 3 ++ .../lite/tools/optimize/testdata/unpack.bin | Bin 0 -> 616 bytes 6 files changed, 63 insertions(+), 1 deletion(-) create mode 100644 tensorflow/lite/tools/optimize/testdata/unpack.bin diff --git a/tensorflow/lite/tools/optimize/BUILD b/tensorflow/lite/tools/optimize/BUILD index 5548567c2c7..ea554b08522 100644 --- a/tensorflow/lite/tools/optimize/BUILD +++ b/tensorflow/lite/tools/optimize/BUILD @@ -242,6 +242,7 @@ tf_cc_test( "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_minus_127_max_plus_127.bin", "//tensorflow/lite/tools/optimize:testdata/single_softmax_min_minus_5_max_plus_5.bin", "//tensorflow/lite/tools/optimize:testdata/split.bin", + "//tensorflow/lite/tools/optimize:testdata/unpack.bin", ], tags = [ "tflite_not_portable_android", diff --git a/tensorflow/lite/tools/optimize/operator_property.cc b/tensorflow/lite/tools/optimize/operator_property.cc index b284e025159..aeb653f6a38 100644 --- a/tensorflow/lite/tools/optimize/operator_property.cc +++ b/tensorflow/lite/tools/optimize/operator_property.cc @@ -70,9 +70,9 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, property.version = 2; break; case BuiltinOperator_SPLIT: - property.arbitrary_outputs = true; // We skip input 0 since it is the split dim which is not real valued. property.inputs = {{1, {}}}; + property.arbitrary_outputs = true; property.restrict_same_input_output_scale = true; property.version = 2; break; @@ -383,6 +383,12 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, property.restrict_same_input_output_scale = true; property.version = 2; break; + case BuiltinOperator_UNPACK: + property.inputs = {{0, {}}}; + property.arbitrary_outputs = true; + property.restrict_same_input_output_scale = true; + property.version = 1; + break; default: // No quantized implementation exists for this operation. property.quantizable = false; diff --git a/tensorflow/lite/tools/optimize/quantize_model_test.cc b/tensorflow/lite/tools/optimize/quantize_model_test.cc index 5e736d96e5b..3e708144d8c 100644 --- a/tensorflow/lite/tools/optimize/quantize_model_test.cc +++ b/tensorflow/lite/tools/optimize/quantize_model_test.cc @@ -1124,6 +1124,56 @@ TEST_F(QuantizeCustomOpTest, VerifyMixedQuantization) { } } +class QuantizeUnpackTest : public QuantizeModelTest { +protected: + QuantizeUnpackTest() { + input_model_ = ReadModel(internal::kModelWithUnpack); + readonly_model_ = input_model_->GetModel(); + readonly_model_->UnPackTo(&model_); + } +}; + +TEST_F(QuantizeUnpackTest, VerifyUnpack) { + auto status = QuantizeModel(&builder_, &model_, &error_reporter_); + + ASSERT_EQ(kTfLiteOk, status); + + const auto subgraph = model_.subgraphs[0].get(); + auto op = subgraph->operators[1].get(); + + auto float_graph = readonly_model_->subgraphs()->Get(0); + + ASSERT_EQ(model_.operator_codes[op->opcode_index].get()->builtin_code, + BuiltinOperator_UNPACK); + + // Get unpack input and output tensors + auto unpack_input = subgraph->tensors[op->inputs[0]].get(); + auto unpack_output_0 = subgraph->tensors[op->outputs[0]].get(); + auto unpack_output_1 = subgraph->tensors[op->outputs[1]].get(); + + // Verify Unpack input is quantized. + ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(), + TensorType_FLOAT32); + EXPECT_EQ(unpack_input->type, TensorType_INT8); + + // The model should only have one input and 2 outputs. + EXPECT_EQ(subgraph->inputs.size(), 1); + EXPECT_EQ(subgraph->outputs.size(), 2); + + // Ensure quantization parameters before and after unpack + // are preserved after quantization for all outputs of + // unpack. + EXPECT_FLOAT_EQ(unpack_input->quantization->scale[0], + unpack_output_0->quantization->scale[0]); + EXPECT_FLOAT_EQ(unpack_input->quantization->scale[0], + unpack_output_1->quantization->scale[0]); + EXPECT_FLOAT_EQ(unpack_input->quantization->zero_point[0], + unpack_output_0->quantization->zero_point[0]); + EXPECT_FLOAT_EQ(unpack_input->quantization->zero_point[0], + unpack_output_1->quantization->zero_point[0]); + +} + } // namespace } // namespace optimize } // namespace tflite diff --git a/tensorflow/lite/tools/optimize/test_util.cc b/tensorflow/lite/tools/optimize/test_util.cc index 74524a18081..ecceacb278c 100644 --- a/tensorflow/lite/tools/optimize/test_util.cc +++ b/tensorflow/lite/tools/optimize/test_util.cc @@ -52,6 +52,8 @@ const char* kModelSplit = "split.bin"; const char* kLstmCalibrated = "lstm_calibrated.bin"; const char* kLstmQuantized = "lstm_quantized.bin"; +const char* kModelWithUnpack = "unpack.bin"; + int FailOnErrorReporter::Report(const char* format, va_list args) { char buf[1024]; vsnprintf(buf, sizeof(buf), format, args); diff --git a/tensorflow/lite/tools/optimize/test_util.h b/tensorflow/lite/tools/optimize/test_util.h index 12c46aa882b..7690ab212cf 100644 --- a/tensorflow/lite/tools/optimize/test_util.h +++ b/tensorflow/lite/tools/optimize/test_util.h @@ -80,6 +80,9 @@ extern const char* kModelSplit; extern const char* kLstmCalibrated; extern const char* kLstmQuantized; +// Test model with an unpack op. +extern const char* kModelWithUnpack; + // An error reporter that fails on testing. class FailOnErrorReporter : public ErrorReporter { public: diff --git a/tensorflow/lite/tools/optimize/testdata/unpack.bin b/tensorflow/lite/tools/optimize/testdata/unpack.bin new file mode 100644 index 0000000000000000000000000000000000000000..72e58bfa1eabe61cb31fc3bbcf9dab9211610a3a GIT binary patch literal 616 zcmaJ;F;Buk7=3D`s0|YBz|et#ku)?q7~?>KaWa8{!HpnoVk4BK6yxYm@FzGj_zRr; zQ4S8q`rgqJ7ro@&_uad@cklbo12DY0J^_y4!9o=d>Tt<6$N)P;dmDIWd?Zc@i`Xh+ z=R41-gKM4E|@ImgA_%eHz zId{nYVZYaJ_T~vSvnce<`*rdTy_T*WuAhjE2nlJ+XM=FftmYqBKe(F#E@_4|#!E`){ Y?&i}lO3i&~$fx{DXWc*B^pA_1AImUHkN^Mx literal 0 HcmV?d00001 From 6cdf78e1a9440fb278f559c858c4c1721fa53d5d Mon Sep 17 00:00:00 2001 From: Amit Kumar Jaiswal Date: Sat, 9 Nov 2019 12:54:05 +0000 Subject: [PATCH 018/279] Update ops --- tensorflow/python/ops/check_ops.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index 34106f61fd8..ee7d1097b78 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -1594,10 +1594,10 @@ def assert_shapes_v2(shapes, data=None, summarize=None, message=None, ```python tf.assert_shapes([ - (x: ('N', 'Q')), - (y: ('N', 'D')), - (param: ('Q',)), - (scalar: ()), + (x, ('N', 'Q')), + (y, ('N', 'D')), + (param, ('Q',)), + (scalar, ()), ]) ``` From bea524e28b3aa40c1106c4d769fc5da3e888b0f3 Mon Sep 17 00:00:00 2001 From: Amit Kumar Jaiswal Date: Fri, 15 Nov 2019 19:24:32 +0000 Subject: [PATCH 019/279] Update check_ops.py --- tensorflow/python/ops/check_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index ee7d1097b78..64adfe13163 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -1593,12 +1593,12 @@ def assert_shapes_v2(shapes, data=None, summarize=None, message=None, Example: ```python - tf.assert_shapes([ + tf.assert_shapes({ (x, ('N', 'Q')), (y, ('N', 'D')), (param, ('Q',)), (scalar, ()), - ]) + }) ``` If `x`, `y`, `param` or `scalar` does not have a shape that satisfies From 5ce67c658cb603f10d8e4806a512e2e78944c6c6 Mon Sep 17 00:00:00 2001 From: Amit Kumar Jaiswal Date: Wed, 20 Nov 2019 00:09:14 +0000 Subject: [PATCH 020/279] Update check_ops.py Signed-off: Amit Kumar Jaiswal --- tensorflow/python/ops/check_ops.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index 64adfe13163..3e4af9f8aa2 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -1594,11 +1594,12 @@ def assert_shapes_v2(shapes, data=None, summarize=None, message=None, ```python tf.assert_shapes({ - (x, ('N', 'Q')), - (y, ('N', 'D')), - (param, ('Q',)), - (scalar, ()), + x: ('N', 'Q'), + y: ('N', 'D'), + param: ('Q',), + scalar: (), }) + ``` If `x`, `y`, `param` or `scalar` does not have a shape that satisfies From b2094d4c6e10ec46c2aa3ba3cfffaf0291874869 Mon Sep 17 00:00:00 2001 From: Amit Kumar Jaiswal Date: Wed, 20 Nov 2019 13:45:14 +0000 Subject: [PATCH 021/279] Update tf.assert_shapes --- tensorflow/python/ops/check_ops.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index 3e4af9f8aa2..8a23fd37385 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -1647,10 +1647,10 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None): ```python tf.assert_shapes({ - (x, ('N', 'Q')), - (y, ('N', 'D')), - (param, ('Q',)), - (scalar, ()) + x, ('N', 'Q'), + y, ('N', 'D'), + param, ('Q',), + scalar, (), }) ``` From 77c098e2cd9ce42d68d011c4510bd5ec7e412d93 Mon Sep 17 00:00:00 2001 From: Harry Slatyer Date: Tue, 26 Nov 2019 09:30:01 +1100 Subject: [PATCH 022/279] Always use real exp when scaling in linalg.expm Previously, for matrices of complex type, the linalg.expm implementation would compute 2**squarings as complex type (despite being an entirely real computation). However, there's no GPU kernel for complex exponential. Instead, we can compute 2**squarings as a real type. We still do the same number of casts as before (one), but now the cast happens after the exponential instead of before. As a result of this change, the entire linalg.expm computation can run on GPU. --- tensorflow/python/ops/linalg/linalg_impl.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/tensorflow/python/ops/linalg/linalg_impl.py b/tensorflow/python/ops/linalg/linalg_impl.py index 9cdcaee6ac2..18d22968c94 100644 --- a/tensorflow/python/ops/linalg/linalg_impl.py +++ b/tensorflow/python/ops/linalg/linalg_impl.py @@ -298,11 +298,9 @@ def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin math_ops.log(l1_norm / maxnorm) / math_ops.log(const(2.0))), 0) u3, v3 = _matrix_exp_pade3(matrix) u5, v5 = _matrix_exp_pade5(matrix) - u7, v7 = _matrix_exp_pade7(matrix / math_ops.pow( - constant_op.constant(2.0, dtype=matrix.dtype), - math_ops.cast( - squarings, - matrix.dtype))[..., array_ops.newaxis, array_ops.newaxis]) + u7, v7 = _matrix_exp_pade7(matrix / math_ops.cast( + math_ops.pow(const(2.0), squarings), + matrix.dtype)[..., array_ops.newaxis, array_ops.newaxis]) conds = (4.258730016922831e-001, 1.880152677804762e+000) u = _nest_where(conds, (u3, u5, u7)) v = _nest_where(conds, (v3, v5, v7)) @@ -315,11 +313,9 @@ def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin u5, v5 = _matrix_exp_pade5(matrix) u7, v7 = _matrix_exp_pade7(matrix) u9, v9 = _matrix_exp_pade9(matrix) - u13, v13 = _matrix_exp_pade13(matrix / math_ops.pow( - constant_op.constant(2.0, dtype=matrix.dtype), - math_ops.cast( - squarings, - matrix.dtype))[..., array_ops.newaxis, array_ops.newaxis]) + u13, v13 = _matrix_exp_pade13(matrix / math_ops.cast( + math_ops.pow(const(2.0), squarings), + matrix.dtype)[..., array_ops.newaxis, array_ops.newaxis]) conds = (1.495585217958292e-002, 2.539398330063230e-001, 9.504178996162932e-001, 2.097847961257068e+000) u = _nest_where(conds, (u3, u5, u7, u9, u13)) From cf6f86af774eaa0a64f05ac2ca73a6433775102a Mon Sep 17 00:00:00 2001 From: Duncan Riach Date: Mon, 25 Nov 2019 18:08:55 -0800 Subject: [PATCH 023/279] Relocate comment that applies to both newer and older code. --- tensorflow/core/kernels/resize_bilinear_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/resize_bilinear_op.cc b/tensorflow/core/kernels/resize_bilinear_op.cc index 0a19a574d7f..d58c52f0a8f 100644 --- a/tensorflow/core/kernels/resize_bilinear_op.cc +++ b/tensorflow/core/kernels/resize_bilinear_op.cc @@ -230,6 +230,7 @@ struct ResizeBilinear { std::vector ys(out_height + 1); std::vector xs(out_width + 1); + // Compute the cached interpolation weights on the x and y dimensions. if (half_pixel_centers) { compute_interpolation_weights(HalfPixelScaler(), out_height, in_height, height_scale, ys.data()); @@ -237,7 +238,6 @@ struct ResizeBilinear { width_scale, xs.data()); } else { - // Compute the cached interpolation weights on the x and y dimensions. compute_interpolation_weights(LegacyScaler(), out_height, in_height, height_scale, ys.data()); compute_interpolation_weights(LegacyScaler(), out_width, in_width, From fae2e9a2949de8697466295e1941f8dd7068e409 Mon Sep 17 00:00:00 2001 From: Amit Kumar Jaiswal Date: Tue, 26 Nov 2019 16:57:04 +0000 Subject: [PATCH 024/279] Update tf.assert --- tensorflow/python/ops/check_ops.py | 40 +++++++++++++++++++----------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index 8a23fd37385..839da5fb90f 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -1592,15 +1592,27 @@ def assert_shapes_v2(shapes, data=None, summarize=None, message=None, Example: - ```python - tf.assert_shapes({ - x: ('N', 'Q'), - y: ('N', 'D'), - param: ('Q',), - scalar: (), - }) + >>> n = 10 + >>> q = 3 + >>> d = 7 + >>> x = tf.zeros([n,q]) + >>> y = tf.ones([n,d]) + >>> param = tf.Variable([1.0, 2.0, 3.0]) + >>> scalar = 1.0 + >>> tf.debugging.assert_shapes([ + ... (x, ('N', 'Q')), + ... (y, ('N', 'D')), + ... (param, ('Q',)), + ... (scalar, ()), + ... ]) - ``` + >>> tf.debugging.assert_shapes([ + ... (x, ('N', 'D')), + ... (y, ('N', 'D')) + ... ]) + Traceback (most recent call last): + ... + ValueError: ... If `x`, `y`, `param` or `scalar` does not have a shape that satisfies all specified constraints, `message`, as well as the first `summarize` entries @@ -1646,12 +1658,12 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None): Example: ```python - tf.assert_shapes({ - x, ('N', 'Q'), - y, ('N', 'D'), - param, ('Q',), - scalar, (), - }) + tf.assert_shapes([ + (x, ('N', 'Q')), + (y, ('N', 'D')), + (param, ('Q',)), + (scalar, ()) + ]) ``` Example of adding a dependency to an operation: From dbaee60ba2582e717ce2063e2961bb1519a1a8da Mon Sep 17 00:00:00 2001 From: Fei Hu Date: Fri, 12 Jul 2019 15:26:06 -0700 Subject: [PATCH 025/279] Add kZlib to compression namespace --- tensorflow/core/lib/io/compression.cc | 1 + tensorflow/core/lib/io/compression.h | 1 + tensorflow/core/lib/io/record_reader.cc | 2 +- tensorflow/core/lib/io/record_writer.cc | 2 +- 4 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/lib/io/compression.cc b/tensorflow/core/lib/io/compression.cc index 0aa4caaaef8..116608dcc55 100644 --- a/tensorflow/core/lib/io/compression.cc +++ b/tensorflow/core/lib/io/compression.cc @@ -22,6 +22,7 @@ namespace compression { const char kNone[] = ""; const char kGzip[] = "GZIP"; const char kSnappy[] = "SNAPPY"; +const char kZlib[] = "ZLIB"; } // namespace compression } // namespace io diff --git a/tensorflow/core/lib/io/compression.h b/tensorflow/core/lib/io/compression.h index 10981846d0a..7856ea8bb00 100644 --- a/tensorflow/core/lib/io/compression.h +++ b/tensorflow/core/lib/io/compression.h @@ -23,6 +23,7 @@ namespace compression { extern const char kNone[]; extern const char kGzip[]; extern const char kSnappy[]; +extern const char kZlib[]; } // namespace compression } // namespace io diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc index 2c24a74f54b..1af81bd902c 100644 --- a/tensorflow/core/lib/io/record_reader.cc +++ b/tensorflow/core/lib/io/record_reader.cc @@ -31,7 +31,7 @@ namespace io { RecordReaderOptions RecordReaderOptions::CreateRecordReaderOptions( const string& compression_type) { RecordReaderOptions options; - if (compression_type == "ZLIB") { + if (compression_type == compression::kZlib) { options.compression_type = io::RecordReaderOptions::ZLIB_COMPRESSION; #if defined(IS_SLIM_BUILD) LOG(ERROR) << "Compression is not supported but compression_type is set." diff --git a/tensorflow/core/lib/io/record_writer.cc b/tensorflow/core/lib/io/record_writer.cc index 405e65a2a6a..52d0ef9a358 100644 --- a/tensorflow/core/lib/io/record_writer.cc +++ b/tensorflow/core/lib/io/record_writer.cc @@ -31,7 +31,7 @@ bool IsZlibCompressed(RecordWriterOptions options) { RecordWriterOptions RecordWriterOptions::CreateRecordWriterOptions( const string& compression_type) { RecordWriterOptions options; - if (compression_type == "ZLIB") { + if (compression_type == compression::kZlib) { options.compression_type = io::RecordWriterOptions::ZLIB_COMPRESSION; #if defined(IS_SLIM_BUILD) LOG(ERROR) << "Compression is not supported but compression_type is set." From 7fd47f10140a3965284f700f4a7cfea62a45019f Mon Sep 17 00:00:00 2001 From: Duncan Riach Date: Tue, 26 Nov 2019 15:50:14 -0800 Subject: [PATCH 026/279] Improve resize_bilinear CPU back-prop kernel comment --- tensorflow/core/kernels/resize_bilinear_op.cc | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/kernels/resize_bilinear_op.cc b/tensorflow/core/kernels/resize_bilinear_op.cc index 0a19a574d7f..61d0732245e 100644 --- a/tensorflow/core/kernels/resize_bilinear_op.cc +++ b/tensorflow/core/kernels/resize_bilinear_op.cc @@ -309,13 +309,16 @@ struct ResizeBilinearGrad { output_grad.setZero(); - // Each resized pixel was computed as a weighted average of four input - // pixels. Here we find the pixels that contributed to each output pixel - // and add the corresponding coefficient to the gradient. - // resized(b, y, x, c) = top_left * (1 - y) * (1 - x) - // + top_right * (1 - y) * x - // + bottom_left * y * (1 - x) - // + bottom_right * y * x + // Each resized output pixel was computed as a weighted average of four + // input pixels. Here we find the four input pixel locations that + // contributed to each output pixel and propgate the gradient at the output + // pixel location to each of those four input pixels locations in the same + // proportions that they originally contributed to the output pixel. + // Here is the forward-propagation pseudo-code, for reference: + // resized(b, y, x, c) = top_left * (1 - y) * (1 - x) + // + top_right * (1 - y) * x + // + bottom_left * y * (1 - x) + // + bottom_right * y * x for (Eigen::Index b = 0; b < batch; ++b) { for (Eigen::Index y = 0; y < resized_height; ++y) { const float in_y = scaler(y, height_scale); From b0c6ea8fdeb0cc2ffa892362c6ee6b41556c185b Mon Sep 17 00:00:00 2001 From: Duncan Riach Date: Tue, 26 Nov 2019 15:59:09 -0800 Subject: [PATCH 027/279] Fix small typo --- tensorflow/core/kernels/resize_bilinear_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/resize_bilinear_op.cc b/tensorflow/core/kernels/resize_bilinear_op.cc index 61d0732245e..46815ccba5d 100644 --- a/tensorflow/core/kernels/resize_bilinear_op.cc +++ b/tensorflow/core/kernels/resize_bilinear_op.cc @@ -312,7 +312,7 @@ struct ResizeBilinearGrad { // Each resized output pixel was computed as a weighted average of four // input pixels. Here we find the four input pixel locations that // contributed to each output pixel and propgate the gradient at the output - // pixel location to each of those four input pixels locations in the same + // pixel location to each of those four input pixel locations in the same // proportions that they originally contributed to the output pixel. // Here is the forward-propagation pseudo-code, for reference: // resized(b, y, x, c) = top_left * (1 - y) * (1 - x) From 860666581f52b75ccb4ec283f99546b86d955925 Mon Sep 17 00:00:00 2001 From: Lukas Geiger Date: Wed, 27 Nov 2019 02:01:27 +0000 Subject: [PATCH 028/279] Fix TensorFlow pip API generation --- tensorflow/api_template.__init__.py | 10 ++++++---- tensorflow/api_template_v1.__init__.py | 9 +++++---- .../python/tools/api/generator/create_python_api.py | 12 +++++++++--- tensorflow/virtual_root_template_v1.__init__.py | 3 --- tensorflow/virtual_root_template_v2.__init__.py | 10 ---------- 5 files changed, 20 insertions(+), 24 deletions(-) diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 56d65d45faf..c515cc76b9a 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -119,11 +119,11 @@ def _running_from_pip_package(): _current_file_location.startswith(dir_) for dir_ in _site_packages_dirs) if _running_from_pip_package(): - for s in _site_packages_dirs: + for _s in _site_packages_dirs: # TODO(gunan): Add sanity checks to loaded modules here. - plugin_dir = _os.path.join(s, 'tensorflow-plugins') - if _fi.file_exists(plugin_dir): - _ll.load_library(plugin_dir) + _plugin_dir = _os.path.join(_s, 'tensorflow-plugins') + if _fi.file_exists(_plugin_dir): + _ll.load_library(_plugin_dir) # Add module aliases if hasattr(_current_module, 'keras'): @@ -136,3 +136,5 @@ if hasattr(_current_module, 'keras'): setattr(_current_module, "optimizers", optimizers) setattr(_current_module, "initializers", initializers) # pylint: enable=undefined-variable + +# __all__ PLACEHOLDER diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index 97478a18b8a..2b2899c3fe0 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -132,9 +132,10 @@ def _running_from_pip_package(): _current_file_location.startswith(dir_) for dir_ in _site_packages_dirs) if _running_from_pip_package(): - for s in _site_packages_dirs: + for _s in _site_packages_dirs: # TODO(gunan): Add sanity checks to loaded modules here. - plugin_dir = _os.path.join(s, 'tensorflow-plugins') - if _fi.file_exists(plugin_dir): - _ll.load_library(plugin_dir) + _plugin_dir = _os.path.join(_s, 'tensorflow-plugins') + if _fi.file_exists(_plugin_dir): + _ll.load_library(_plugin_dir) +# __all__ PLACEHOLDER diff --git a/tensorflow/python/tools/api/generator/create_python_api.py b/tensorflow/python/tools/api/generator/create_python_api.py index 3af677322d6..80f663683c3 100644 --- a/tensorflow/python/tools/api/generator/create_python_api.py +++ b/tensorflow/python/tools/api/generator/create_python_api.py @@ -243,11 +243,12 @@ class _ModuleInitCodeBuilder(object): # from it using * import. Don't need this for lazy_loading because the # underscore symbols are already included in __all__ when passed in and # handled by TFModuleWrapper. + root_module_footer = '' if not self._lazy_loading: underscore_names_str = ', '.join( '\'%s\'' % name for name in self._underscore_names_in_root) - module_text_map[''] = module_text_map.get('', '') + ''' + root_module_footer = ''' _names_with_underscore = [%s] __all__ = [_s for _s in dir() if not _s.startswith('_')] __all__.extend([_s for _s in _names_with_underscore]) @@ -273,7 +274,7 @@ __all__.extend([_s for _s in _names_with_underscore]) footer_text_map[dest_module] = _DEPRECATION_FOOTER % ( dest_module, public_apis_name, deprecation, has_lite) - return module_text_map, footer_text_map + return module_text_map, footer_text_map, root_module_footer def format_import(self, source_module_name, source_name, dest_name): """Formats import statement. @@ -620,7 +621,11 @@ def create_api_files(output_files, packages, root_init_template, output_dir, os.makedirs(os.path.dirname(file_path)) open(file_path, 'a').close() - module_text_map, deprecation_footer_map = get_api_init_text( + ( + module_text_map, + deprecation_footer_map, + root_module_footer, + ) = get_api_init_text( packages, output_package, api_name, api_version, compat_api_versions, lazy_loading, use_relative_imports) @@ -652,6 +657,7 @@ def create_api_files(output_files, packages, root_init_template, output_dir, with open(root_init_template, 'r') as root_init_template_file: contents = root_init_template_file.read() contents = contents.replace('# API IMPORTS PLACEHOLDER', text) + contents = contents.replace('# __all__ PLACEHOLDER', root_module_footer) elif module in compat_module_to_template: # Read base init file for compat module with open(compat_module_to_template[module], 'r') as init_template_file: diff --git a/tensorflow/virtual_root_template_v1.__init__.py b/tensorflow/virtual_root_template_v1.__init__.py index 236e9f52258..9a45bc0355d 100644 --- a/tensorflow/virtual_root_template_v1.__init__.py +++ b/tensorflow/virtual_root_template_v1.__init__.py @@ -132,7 +132,4 @@ try: except NameError: pass -# Manually patch keras and estimator so tf.keras and tf.estimator work -keras = _sys.modules["tensorflow.keras"] -if not _root_estimator: estimator = _sys.modules["tensorflow.estimator"] # LINT.ThenChange(//tensorflow/virtual_root_template_v2.__init__.py.oss) diff --git a/tensorflow/virtual_root_template_v2.__init__.py b/tensorflow/virtual_root_template_v2.__init__.py index 83c020182a8..bd8c903e455 100644 --- a/tensorflow/virtual_root_template_v2.__init__.py +++ b/tensorflow/virtual_root_template_v2.__init__.py @@ -126,14 +126,4 @@ try: except NameError: pass -# TODO(mihaimaruseac): Revisit all of this once we release 2.1 -# Manually patch keras and estimator so tf.keras and tf.estimator work -keras = _sys.modules["tensorflow.keras"] -if not _root_estimator: estimator = _sys.modules["tensorflow.estimator"] -# Also import module aliases -try: - from tensorflow_core import losses, metrics, initializers, optimizers -except ImportError: - pass - # LINT.ThenChange(//tensorflow/virtual_root_template_v1.__init__.py.oss) From 4d9297306254d2584c79b08fb43a2eaf705e9771 Mon Sep 17 00:00:00 2001 From: Michal Tarnowski Date: Wed, 27 Nov 2019 13:26:19 +0100 Subject: [PATCH 029/279] Added comment on kNewtonSteps --- tensorflow/lite/kernels/internal/optimized/optimized_ops.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h index 2c09ca9b94f..9126a4b6797 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -2640,6 +2640,11 @@ inline void Div(const ArithmeticParams& params, int i = 0; const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape); #ifdef USE_NEON + // NEON does not offer division instruction, multiplication by the reciprocal + // is used instead. This parameter controls the number of Newton-Raphson + // iterations used to refine the initial estimate of the reciprocal given by + // vrecpeq_f32 instruction. Typically, two iterations are enough to match + // the float division accuracy closely. static constexpr int kNewtonSteps = 2; static const auto TWO_F32 = vdupq_n_f32(2.f); const auto activation_min = vdupq_n_f32(output_activation_min); From d81dc812bdbcb665fa805b8ee4ce04bf0a2b572d Mon Sep 17 00:00:00 2001 From: Davide Libenzi Date: Wed, 27 Nov 2019 06:27:24 -0800 Subject: [PATCH 030/279] Added comment to all-reduce simplifier. PiperOrigin-RevId: 282757340 Change-Id: I10f79c315dcd97cad2c28720f9797929d1d2ead0 --- tensorflow/compiler/xla/service/all_reduce_simplifier.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tensorflow/compiler/xla/service/all_reduce_simplifier.cc b/tensorflow/compiler/xla/service/all_reduce_simplifier.cc index e541bfea11f..b3097b8ff77 100644 --- a/tensorflow/compiler/xla/service/all_reduce_simplifier.cc +++ b/tensorflow/compiler/xla/service/all_reduce_simplifier.cc @@ -34,6 +34,9 @@ StatusOr AllReduceSimplifier::Run(HloModule* module) { for (HloInstruction* inst : computation->MakeInstructionPostOrder()) { if (!inst->shape().IsArray()) { // We currently do not change tuple-shaped all-reduce. + // Until XLA will support Token fed AllReduce(), the PyTorch client code + // uses a fake data token (constant) which relies on this pass to not + // optimize out (being fed within a tuple input). continue; } if (inst->IsCrossReplicaAllReduce() && From 957b33238c46d16feecc47895dc38ee728423256 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Wed, 27 Nov 2019 07:09:07 -0800 Subject: [PATCH 031/279] Remove unused private method: IntraProcessRendezvous::ParseKey(). PiperOrigin-RevId: 282762681 Change-Id: I233481030d5a937f4fd4c17c0e83f3356524a357 --- tensorflow/core/common_runtime/rendezvous_mgr.cc | 10 ---------- tensorflow/core/common_runtime/rendezvous_mgr.h | 6 ------ 2 files changed, 16 deletions(-) diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.cc b/tensorflow/core/common_runtime/rendezvous_mgr.cc index d6fea8bd5d5..4d296252f69 100644 --- a/tensorflow/core/common_runtime/rendezvous_mgr.cc +++ b/tensorflow/core/common_runtime/rendezvous_mgr.cc @@ -50,16 +50,6 @@ Status IntraProcessRendezvous::Send(const ParsedKey& parsed, return local_->Send(parsed, args, val, is_dead); } -Status IntraProcessRendezvous::ParseKey(const string& key, bool is_src, - Rendezvous::ParsedKey* parsed) { - { - mutex_lock l(mu_); - if (!status_.ok()) return status_; - } - TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, parsed)); - return Status::OK(); -} - void IntraProcessRendezvous::SameWorkerRecvDone( const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& send_args, const Rendezvous::Args& recv_args, const Tensor& in, Tensor* out, diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.h b/tensorflow/core/common_runtime/rendezvous_mgr.h index b4d8ab4eb2b..1f7e6f28aeb 100644 --- a/tensorflow/core/common_runtime/rendezvous_mgr.h +++ b/tensorflow/core/common_runtime/rendezvous_mgr.h @@ -66,12 +66,6 @@ class IntraProcessRendezvous : public Rendezvous { ~IntraProcessRendezvous() override; - // Parses "key" into "parsed". If "is_src" is true, checks that the - // rendezvous key's source is in this process. If "is_src" is false, - // checks that the rendezvous key's destination is in this process. - Status ParseKey(const string& key, bool is_src, - Rendezvous::ParsedKey* parsed); - // Callback handling the case when a rendezvous has been // accomplished in local_ and the consumer is local to this process. // Tensor "in" will be copied into "out". The key "parsed" encodes From 81f844c1ff2bee0c3a98a7fff7b308ad77d85309 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Wed, 27 Nov 2019 07:20:21 -0800 Subject: [PATCH 032/279] Split Rendezvous class into pure-virtual RendezvousInterface and refcounted Rendezvous. This change lays the groundwork for creating non-refcounted RendezvousInterface implementations, which would allow us to avoid dynamic allocation and atomic refcount operations in some cases. It modifies internal classes that use Rendezvous* to use RendezvousInterface* instead: the change is safe because none of these rely on the ability to modify the rendezvous' refcount (and it is unlikely that it would be safe for them to do so). PiperOrigin-RevId: 282764107 Change-Id: I8ef6fe995962dfa6556ae066f990c6445462a13e --- tensorflow/core/common_runtime/executor.cc | 2 +- tensorflow/core/common_runtime/executor.h | 2 +- tensorflow/core/common_runtime/function.cc | 2 +- .../core/common_runtime/function_test.cc | 5 +-- .../core/common_runtime/graph_runner.cc | 13 ++++--- .../process_function_library_runtime.cc | 7 ++-- .../process_function_library_runtime.h | 4 +-- .../core/common_runtime/rendezvous_util.cc | 7 ++-- .../core/common_runtime/rendezvous_util.h | 7 ++-- tensorflow/core/framework/function.h | 2 +- tensorflow/core/framework/op_kernel.h | 4 +-- tensorflow/core/framework/rendezvous.cc | 10 +++--- tensorflow/core/framework/rendezvous.h | 34 ++++++++++++------- 13 files changed, 55 insertions(+), 44 deletions(-) diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index e1135ec488c..1c04adf7872 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -1287,7 +1287,7 @@ class ExecutorState { int64 step_id_; // Not owned. - Rendezvous* rendezvous_; + RendezvousInterface* rendezvous_; Executor::RendezvousFactory* create_rendezvous_ = nullptr; CollectiveExecutor* collective_executor_ = nullptr; SessionState* session_state_; diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h index 42d5b9eab4f..c147deee694 100644 --- a/tensorflow/core/common_runtime/executor.h +++ b/tensorflow/core/common_runtime/executor.h @@ -88,7 +88,7 @@ class Executor { struct Args { int64 step_id = 0; - Rendezvous* rendezvous = nullptr; + RendezvousInterface* rendezvous = nullptr; StepStatsCollectorInterface* stats_collector = nullptr; CallFrameInterface* call_frame = nullptr; CancellationManager* cancellation_manager = nullptr; diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 668a671d5e8..aa3be38fd29 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -1017,7 +1017,7 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle, Item* item, DoneCallback done) { string target_device = parent_->GetDeviceName(handle); string source_device = opts.source_device; - Rendezvous* rendezvous = opts.rendezvous; + RendezvousInterface* rendezvous = opts.rendezvous; DeviceContext* device_context; Status s = parent_->GetDeviceContext(target_device, &device_context); if (!s.ok()) { diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index c3d6e948f1e..7c76c469d1e 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -1854,7 +1854,8 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) { Tensor y; FunctionLibraryRuntime::Options opts; - opts.rendezvous = new IntraProcessRendezvous(device_mgr_.get()); + Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_.get()); + opts.rendezvous = rendezvous; opts.source_device = "/device:CPU:1"; // Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1. TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y}, true)); @@ -1869,7 +1870,7 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) { y, test::AsTensor({"/job:localhost/replica:0/task:0/device:CPU:1"}, TensorShape({}))); - opts.rendezvous->Unref(); + rendezvous->Unref(); } namespace { diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc index 48ee4b11a33..0a7d50f9ea4 100644 --- a/tensorflow/core/common_runtime/graph_runner.cc +++ b/tensorflow/core/common_runtime/graph_runner.cc @@ -45,7 +45,7 @@ namespace { // A simple rendezvous class. // Assumes a single sender and a single receiver, no duplicate sends, and no // sends of dead tensors. -class SimpleRendezvous : public Rendezvous { +class SimpleRendezvous : public RendezvousInterface { public: explicit SimpleRendezvous() {} @@ -124,8 +124,7 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library, std::unique_ptr graph_to_run(new Graph(graph->op_registry())); CopyGraph(*graph, graph_to_run.get()); - SimpleRendezvous* rendez = new SimpleRendezvous; - core::ScopedUnref rendez_unref(rendez); + SimpleRendezvous rendez; // Extract the input names and keys, and feed in the inputs. std::vector input_names; @@ -136,8 +135,8 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library, tensor_name, FrameAndIter(0, 0)); Rendezvous::ParsedKey parsed; TF_RETURN_IF_ERROR(Rendezvous::ParseKey(full_key, &parsed)); - TF_RETURN_IF_ERROR(rendez->Send(parsed, Rendezvous::Args(), in.second, - false /* is_dead */)); + TF_RETURN_IF_ERROR(rendez.Send(parsed, Rendezvous::Args(), in.second, + false /* is_dead */)); } // Call RewriteGraphForExecution @@ -180,7 +179,7 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library, // called via this method. args.step_id = LogMemory::CONSTANT_FOLDING_STEP_ID; args.runner = runner; - args.rendezvous = rendez; + args.rendezvous = &rendez; // NOTE: Use of graph runner is limited to single-device executions // so a CollectiveExecutor should never be required. args.collective_executor = nullptr; @@ -201,7 +200,7 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library, bool is_dead; Tensor output_tensor; TF_RETURN_IF_ERROR( - rendez->Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead)); + rendez.Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead)); // Does a deep copy so that ownership of the tensor isn't tied to the // allocator of the cpu device we created above. The allocator could be // deleted along with the device. diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 5c08d330ccf..4c01978e6d5 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -122,7 +122,7 @@ Status ProcessFunctionLibraryRuntime::SendTensors( const string& key_prefix, int64 src_incarnation, gtl::ArraySlice tensors_to_send, DeviceContext* device_context, const std::vector& alloc_attrs, - Rendezvous* rendezvous) { + RendezvousInterface* rendezvous) { std::vector keys; for (int i = 0; i < tensors_to_send.size(); ++i) { string name = strings::StrCat(key_prefix, i); @@ -140,8 +140,9 @@ void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync( const string& source_device, const string& target_device, const string& key_prefix, int64 src_incarnation, int64 num_tensors, DeviceContext* device_context, - const std::vector& alloc_attrs, Rendezvous* rendezvous, - std::vector* received_tensors, StatusCallback done) { + const std::vector& alloc_attrs, + RendezvousInterface* rendezvous, std::vector* received_tensors, + StatusCallback done) { std::vector keys; for (int64 i = 0; i < num_tensors; ++i) { string name = strings::StrCat(key_prefix, i); diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h index 0166267b3ab..ee5d8bf2b16 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/process_function_library_runtime.h @@ -92,7 +92,7 @@ class ProcessFunctionLibraryRuntime { gtl::ArraySlice tensors_to_send, DeviceContext* device_context, const std::vector& alloc_attrs, - Rendezvous* rendezvous); + RendezvousInterface* rendezvous); // Receives `received_tensors` from `target_device` (originally sent from // `source_device`) using `rendezvous`. Uses `key_prefix` to construct the @@ -105,7 +105,7 @@ class ProcessFunctionLibraryRuntime { const string& key_prefix, int64 src_incarnation, int64 num_tensors, DeviceContext* device_context, const std::vector& alloc_attrs, - Rendezvous* rendezvous, std::vector* received_tensors, + RendezvousInterface* rendezvous, std::vector* received_tensors, StatusCallback done); static const char kDefaultFLRDevice[]; diff --git a/tensorflow/core/common_runtime/rendezvous_util.cc b/tensorflow/core/common_runtime/rendezvous_util.cc index 43ca3f1e3e0..df3e9a2452d 100644 --- a/tensorflow/core/common_runtime/rendezvous_util.cc +++ b/tensorflow/core/common_runtime/rendezvous_util.cc @@ -20,7 +20,7 @@ limitations under the License. namespace tensorflow { Status SendTensorsToRendezvous( - Rendezvous* rendezvous, DeviceContext* device_context, + RendezvousInterface* rendezvous, DeviceContext* device_context, const std::vector& alloc_attrs, const std::vector& keys, gtl::ArraySlice tensors_to_send) { if (keys.size() != tensors_to_send.size()) { @@ -54,7 +54,7 @@ Status SendTensorsToRendezvous( } void RecvOutputsFromRendezvousAsync( - Rendezvous* rendezvous, DeviceContext* device_context, + RendezvousInterface* rendezvous, DeviceContext* device_context, const std::vector& alloc_attrs, const std::vector& keys, std::vector* received_tensors, StatusCallback done) { @@ -118,7 +118,8 @@ void RecvOutputsFromRendezvousAsync( status_cb->Unref(); } -Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out, +Status RecvOutputsFromRendezvous(RendezvousInterface* rendezvous, + NamedTensors* out, const Rendezvous::Args& args) { // Receives values requested by the caller. Rendezvous::ParsedKey parsed; diff --git a/tensorflow/core/common_runtime/rendezvous_util.h b/tensorflow/core/common_runtime/rendezvous_util.h index deb9a7c8225..fe95dc0ef57 100644 --- a/tensorflow/core/common_runtime/rendezvous_util.h +++ b/tensorflow/core/common_runtime/rendezvous_util.h @@ -31,7 +31,7 @@ typedef std::function StatusCallback; // allocated. `alloc_attrs` should either be {} or should match the length of // `keys`. Status SendTensorsToRendezvous( - Rendezvous* rendezvous, DeviceContext* device_context, + RendezvousInterface* rendezvous, DeviceContext* device_context, const std::vector& alloc_attrs, const std::vector& keys, gtl::ArraySlice tensors_to_send); @@ -40,12 +40,13 @@ Status SendTensorsToRendezvous( // information as how to store the received tensors. Should be {} or match the // length of `keys`. void RecvOutputsFromRendezvousAsync( - Rendezvous* rendezvous, DeviceContext* device_context, + RendezvousInterface* rendezvous, DeviceContext* device_context, const std::vector& alloc_attrs, const std::vector& keys, std::vector* received_tensors, StatusCallback done); -Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out, +Status RecvOutputsFromRendezvous(RendezvousInterface* rendezvous, + NamedTensors* out, const Rendezvous::Args& args); } // namespace tensorflow diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 09a94d8b550..0e260d26592 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -687,7 +687,7 @@ class FunctionLibraryRuntime { // tensors to the remote TensorHandles in the default device. absl::optional op_id = absl::nullopt; - Rendezvous* rendezvous = nullptr; + RendezvousInterface* rendezvous = nullptr; CancellationManager* cancellation_manager = nullptr; CollectiveExecutor* collective_executor = nullptr; ScopedStepContainer* step_container = nullptr; diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 8372359e7ae..149667a9965 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -672,7 +672,7 @@ class OpKernelContext { // Mechanism used by this op kernel invocation to communicate with // computations running on other devices. - Rendezvous* rendezvous = nullptr; + RendezvousInterface* rendezvous = nullptr; const std::function* create_rendezvous; @@ -1100,7 +1100,7 @@ class OpKernelContext { // // An op kernel communicates with outside environment through // Rendezvous Send() and Recv(). - Rendezvous* rendezvous() const { return params_->rendezvous; } + RendezvousInterface* rendezvous() const { return params_->rendezvous; } Status create_rendezvous(const int64 step_id, const DeviceMgr* device_mgr, Rendezvous** r) const { return (*params_->create_rendezvous)(step_id, device_mgr, r); diff --git a/tensorflow/core/framework/rendezvous.cc b/tensorflow/core/framework/rendezvous.cc index ad3cf912d23..18be6238225 100644 --- a/tensorflow/core/framework/rendezvous.cc +++ b/tensorflow/core/framework/rendezvous.cc @@ -113,10 +113,10 @@ Status Rendezvous::ParseKey(StringPiece key, ParsedKey* out) { return errors::InvalidArgument("Invalid rendezvous key: ", key); } -Rendezvous::~Rendezvous() {} +RendezvousInterface::~RendezvousInterface() {} -Status Rendezvous::Recv(const ParsedKey& key, const Args& recv_args, - Tensor* val, bool* is_dead, int64 timeout_ms) { +Status RendezvousInterface::Recv(const ParsedKey& key, const Args& recv_args, + Tensor* val, bool* is_dead, int64 timeout_ms) { Status ret; Notification n; RecvAsync(key, recv_args, @@ -141,8 +141,8 @@ Status Rendezvous::Recv(const ParsedKey& key, const Args& recv_args, return ret; } -Status Rendezvous::Recv(const ParsedKey& key, const Args& args, Tensor* val, - bool* is_dead) { +Status RendezvousInterface::Recv(const ParsedKey& key, const Args& args, + Tensor* val, bool* is_dead) { const int64 no_timeout = 0; return Recv(key, args, val, is_dead, no_timeout); } diff --git a/tensorflow/core/framework/rendezvous.h b/tensorflow/core/framework/rendezvous.h index d6e910da991..b9172f63df6 100644 --- a/tensorflow/core/framework/rendezvous.h +++ b/tensorflow/core/framework/rendezvous.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_FRAMEWORK_RENDEZVOUS_H_ -#define TENSORFLOW_FRAMEWORK_RENDEZVOUS_H_ +#ifndef TENSORFLOW_CORE_FRAMEWORK_RENDEZVOUS_H_ +#define TENSORFLOW_CORE_FRAMEWORK_RENDEZVOUS_H_ #include @@ -44,7 +44,7 @@ namespace tensorflow { // been produced. A consumer has the choice of making a blocking call // or providing a callback: in either case, the consumer receives the // Tensor as soon as it is available. A producer never blocks. -class Rendezvous : public core::RefCounted { +class RendezvousInterface { public: struct Args { DeviceContext* device_context = nullptr; @@ -52,13 +52,6 @@ class Rendezvous : public core::RefCounted { CancellationManager* cancellation_manager = nullptr; // not owned. }; - // Constructs a rendezvous key for the tensor of "name" sent from - // "src_device" to "dst_device". The tensor is generated in the frame - // and iteration specified by "frame_iter". - static string CreateKey(const string& src_device, uint64 src_incarnation, - const string& dst_device, const string& name, - const FrameAndIter& frame_iter); - // Parses the key constructed by CreateKey and parse src/dst device // names into structures respectively. struct ParsedKey { @@ -81,7 +74,6 @@ class Rendezvous : public core::RefCounted { friend class RecvOp; string buf_; }; - static Status ParseKey(StringPiece key, ParsedKey* out); // The caller is a tensor producer and it sends a message (a tensor // "val" and a bool "is_dead") under the given "key". @@ -123,12 +115,28 @@ class Rendezvous : public core::RefCounted { virtual void StartAbort(const Status& status) = 0; protected: - ~Rendezvous() override; + virtual ~RendezvousInterface(); virtual bool is_cross_process() { return false; } friend class ProcessFunctionLibraryRuntime; }; +// A reference-counted implementation of RendezvousInterface. +// +// This class is used in cases where a rendezvous may be shared between multiple +// threads with no clear owner. +class Rendezvous : public RendezvousInterface, public core::RefCounted { + public: + // Constructs a rendezvous key for the tensor of "name" sent from + // "src_device" to "dst_device". The tensor is generated in the frame + // and iteration specified by "frame_iter". + static string CreateKey(const string& src_device, uint64 src_incarnation, + const string& dst_device, const string& name, + const FrameAndIter& frame_iter); + + static Status ParseKey(StringPiece key, ParsedKey* out); +}; + // Returns a Rendezvous instance that is limited to use only by // producers and consumers in the local process. The caller assumes // ownership of one Ref() on the returned object. @@ -136,4 +144,4 @@ Rendezvous* NewLocalRendezvous(); } // end namespace tensorflow -#endif // TENSORFLOW_FRAMEWORK_RENDEZVOUS_H_ +#endif // TENSORFLOW_CORE_FRAMEWORK_RENDEZVOUS_H_ From bb00fa819e104d81866d1f4214c5a6ac10a5839a Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Wed, 27 Nov 2019 07:31:41 -0800 Subject: [PATCH 033/279] Implement Linalg to loops lowering as a pattern This CL rewrites the linalg ops to loops transformations as patterns that can be targeted directly from Tablegen. Reliance on OpFolder is removed and to cope with it we introduce local folding patterns that are applied greedily. PiperOrigin-RevId: 282765550 Change-Id: I1cb9dd53a0364d965411b43c0ef1b52837e6af4a --- .../xla/service/mlir_gpu/kernel_lowering.cc | 2 +- third_party/mlir/BUILD | 2 +- .../mlir/include/mlir/Dialect/Linalg/Passes.h | 11 +- .../Transforms/LinalgTransformPatterns.td | 11 + .../Linalg/Transforms/LinalgTransforms.h | 41 ++- .../mlir/lib/Dialect/Linalg/CMakeLists.txt | 2 +- .../{LowerToLoops.cpp => LinalgToLoops.cpp} | 314 ++++++++++++------ .../TestLinalgTransformPatterns.td | 7 + 8 files changed, 261 insertions(+), 129 deletions(-) rename third_party/mlir/lib/Dialect/Linalg/Transforms/{LowerToLoops.cpp => LinalgToLoops.cpp} (64%) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc index ea887926338..7cbbb3ec44e 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc @@ -46,7 +46,7 @@ Status LowerLHLOToGPU(mlir::ModuleOp module) { // Transform element-wise operations to LinAlg. pm.addPass(::mlir::xla_lhlo::createLegalizeToLinalgPass()); // Go from affine to normal loops. - pm.addPass(::mlir::linalg::createLowerLinalgToLoopsPass()); + pm.addPass(::mlir::linalg::createConvertLinalgToLoopsPass()); // Lower affine to ordinary loops. pm.addPass(::mlir::createLowerAffinePass()); // Move constants out of the loop. diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD index 3b09b3cb470..57893543f6f 100644 --- a/third_party/mlir/BUILD +++ b/third_party/mlir/BUILD @@ -2212,8 +2212,8 @@ cc_library( "lib/Dialect/Linalg/IR/LinalgOps.cpp", "lib/Dialect/Linalg/IR/LinalgTypes.cpp", "lib/Dialect/Linalg/Transforms/Fusion.cpp", + "lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp", "lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp", - "lib/Dialect/Linalg/Transforms/LowerToLoops.cpp", "lib/Dialect/Linalg/Transforms/Promotion.cpp", "lib/Dialect/Linalg/Transforms/Tiling.cpp", "lib/Dialect/Linalg/Utils/Utils.cpp", diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/Passes.h b/third_party/mlir/include/mlir/Dialect/Linalg/Passes.h index 5ecd50070da..7ae3877f01e 100644 --- a/third_party/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/third_party/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -39,9 +39,16 @@ createLinalgTilingPass(ArrayRef tileSizes = {}); std::unique_ptr> createLinalgPromotionPass(bool dynamicBuffers); -std::unique_ptr> createLowerLinalgToLoopsPass(); +/// Create a pass to convert Linalg operations to loop.for loops and +/// std.load/std.store accesses. +std::unique_ptr> createConvertLinalgToLoopsPass(); -/// Create a pass to convert vector operations to the LLVMIR dialect. +/// Create a pass to convert Linalg operations to affine.for loops and +/// affine_load/affine_store accesses. +/// Placeholder for now, this is NYI. +std::unique_ptr> createConvertLinalgToAffineLoopsPass(); + +/// Create a pass to convert Linalg operations to the LLVMIR dialect. std::unique_ptr> createConvertLinalgToLLVMPass(); } // namespace linalg diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/third_party/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td index d243bb23f2c..8bc0eaf2097 100644 --- a/third_party/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td +++ b/third_party/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td @@ -62,4 +62,15 @@ class TileLinalgOp sizes, string value> : NativeCodeCall< StrJoinInt.result # "}, \"" # value # "\")))" # " return matchFailure();">; +//===----------------------------------------------------------------------===// +// Linalg to loop patterns. +//===----------------------------------------------------------------------===// +class LinalgOpToLoops : NativeCodeCall< + "if (failed(linalgOpToLoops<" # OpType # ">($_builder, $0))) " # + " return matchFailure();">; + +class LinalgOpToAffineLoops : NativeCodeCall< + "if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, $0))) " # + " return matchFailure();">; + #endif // LINALG_TRANSFORMS diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/third_party/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h index 56ae94f32c6..966b8f93135 100644 --- a/third_party/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h +++ b/third_party/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h @@ -35,20 +35,6 @@ struct LinalgTransforms { static const StringLiteral kLinalgTransformMarker; }; -// Declarative transformation used in tablegen patterns. -// Tiles `op` by `sizes` and sets the attribute `kLinalgTransformMarker` to -// `linalgMarker`. -LogicalResult tileLinalgOpAndSetMarker(PatternRewriter &rewriter, Operation *op, - ArrayRef sizes, - StringRef linalgMarker); - -// Declarative transformation used in tablegen patterns. -// Tiles `op` by `sizes`, fuses the producers of `operandIndicesToFuse` and sets -// the attribute `kLinalgTransformMarker` to `linalgMarker`. -LogicalResult tileAndFuseLinalgOpAndSetMarker( - PatternRewriter &rewriter, Operation *op, ArrayRef sizes, - ArrayRef operandIndicesToFuse, StringRef linalgMarker); - namespace detail { // Implementation detail of isProducedByOpOfType avoids the need for explicit // template instantiations. @@ -65,6 +51,33 @@ bool isProducedByOpOfType(Operation *consumerOp, Value *consumedView) { consumerOp, consumedView, [](Operation *op) { return isa(op); }); } +//////////////////////////////////////////////////////////////////////////////// +// The following Declarative Rewrite Rule (DRR) helpers are used in rewrite +// patterns. As such, they must not call into `rewriter.erase/replace` APIs and +// it is the responsibility of the enclosing PatternRewriter to erase on +// success. +//////////////////////////////////////////////////////////////////////////////// + +// Tiles `op` by `sizes` and sets the attribute `kLinalgTransformMarker` to +// `linalgMarker`. +LogicalResult tileLinalgOpAndSetMarker(PatternRewriter &rewriter, Operation *op, + ArrayRef sizes, + StringRef linalgMarker); + +// Tiles `op` by `sizes`, fuses the producers of `operandIndicesToFuse` and sets +// the attribute `kLinalgTransformMarker` to `linalgMarker`. +LogicalResult tileAndFuseLinalgOpAndSetMarker( + PatternRewriter &rewriter, Operation *op, ArrayRef sizes, + ArrayRef operandIndicesToFuse, StringRef linalgMarker); + +// Emits a loop nest of `loop.for` with the proper body for `op`. +template +LogicalResult linalgOpToLoops(PatternRewriter &rewriter, Operation *op); + +// Emits a loop nest of `affine.for` with the proper body for `op`. +template +LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter, Operation *op); + } // namespace linalg } // namespace mlir diff --git a/third_party/mlir/lib/Dialect/Linalg/CMakeLists.txt b/third_party/mlir/lib/Dialect/Linalg/CMakeLists.txt index 4b7cd81be94..a4ce5038891 100644 --- a/third_party/mlir/lib/Dialect/Linalg/CMakeLists.txt +++ b/third_party/mlir/lib/Dialect/Linalg/CMakeLists.txt @@ -5,7 +5,7 @@ add_llvm_library(MLIRLinalg IR/LinalgTypes.cpp Transforms/Fusion.cpp Transforms/LinalgTransforms.cpp - Transforms/LowerToLoops.cpp + Transforms/LinalgToLoops.cpp Transforms/Promotion.cpp Transforms/Tiling.cpp Utils/Utils.cpp diff --git a/third_party/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp b/third_party/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp similarity index 64% rename from third_party/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp rename to third_party/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp index 0bf4ceaa33b..cf0b235f57f 100644 --- a/third_party/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp +++ b/third_party/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h" #include "mlir/Dialect/Linalg/Utils/Intrinsics.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/LoopOps/LoopOps.h" @@ -41,12 +42,14 @@ using namespace mlir::linalg; using namespace mlir::linalg::intrinsics; using IndexedStdValue = TemplatedIndexedValue; +using IndexedAffineValue = TemplatedIndexedValue; + using edsc::op::operator+; using edsc::op::operator==; static SmallVector -foldedAffineApplies(OpBuilder &b, Location loc, AffineMap map, - ArrayRef vals, OperationFolder *folder) { +makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map, + ArrayRef vals) { assert(map.getNumSymbols() == 0); assert(map.getNumInputs() == vals.size()); SmallVector res; @@ -56,17 +59,16 @@ foldedAffineApplies(OpBuilder &b, Location loc, AffineMap map, auto exprMap = AffineMap::get(dims, 0, e); SmallVector operands(vals.begin(), vals.end()); canonicalizeMapAndOperands(&exprMap, &operands); - res.push_back(affine_apply(folder, exprMap, operands)); + res.push_back(affine_apply(exprMap, operands)); } return res; } static SmallVector permuteIvs(ArrayRef ivs, - Optional permutation, - OperationFolder *folder) { + Optional permutation) { return permutation ? applyMapToValues(ScopedContext::getBuilder(), ScopedContext::getLocation(), - permutation.getValue(), ivs, folder) + permutation.getValue(), ivs) : SmallVector(ivs.begin(), ivs.end()); } @@ -75,20 +77,17 @@ static SmallVector permuteIvs(ArrayRef ivs, // which new loops will be created. static SmallVector emitLoopRanges(OpBuilder &b, Location loc, AffineMap map, - ArrayRef allViewSizes, - OperationFolder *folder); + ArrayRef allViewSizes); SmallVector emitLoopRanges(OpBuilder &b, Location loc, AffineMap map, - ArrayRef allViewSizes, - OperationFolder *folder) { + ArrayRef allViewSizes) { // Apply `map` to get view sizes in loop order. - auto sizes = applyMapToValues(b, loc, map, allViewSizes, folder); + auto sizes = applyMapToValues(b, loc, map, allViewSizes); // Create a new range with the applied tile sizes. ScopedContext scope(b, loc); SmallVector res; for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) { - res.push_back(range(constant_index(folder, 0), sizes[idx], - constant_index(folder, 1))); + res.push_back(range(constant_index(0), sizes[idx], constant_index(1))); } return res; } @@ -99,14 +98,14 @@ class LinalgScopedEmitter {}; template class LinalgScopedEmitter { public: - static void emitScalarImplementation(ArrayRef allIvs, CopyOp copyOp, - OperationFolder *folder) { + static void emitScalarImplementation(ArrayRef allIvs, + CopyOp copyOp) { auto nPar = copyOp.getNumParallelLoops(); assert(nPar == allIvs.size()); auto inputIvs = - permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation(), folder); + permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation()); auto outputIvs = - permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation(), folder); + permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation()); SmallVector iivs(inputIvs.begin(), inputIvs.end()); SmallVector oivs(outputIvs.begin(), outputIvs.end()); IndexedValueType O(copyOp.getOutput(0)), I(copyOp.getInput(0)); @@ -122,8 +121,8 @@ public: template class LinalgScopedEmitter { public: - static void emitScalarImplementation(ArrayRef allIvs, FillOp fillOp, - OperationFolder *folder) { + static void emitScalarImplementation(ArrayRef allIvs, + FillOp fillOp) { auto nPar = fillOp.getNumParallelLoops(); assert(nPar == allIvs.size()); auto ivs = @@ -139,8 +138,7 @@ public: template class LinalgScopedEmitter { public: - static void emitScalarImplementation(ArrayRef allIvs, DotOp dotOp, - OperationFolder *folder) { + static void emitScalarImplementation(ArrayRef allIvs, DotOp dotOp) { assert(allIvs.size() == 1); IndexHandle r_i(allIvs[0]); IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)), @@ -154,8 +152,7 @@ template class LinalgScopedEmitter { public: static void emitScalarImplementation(ArrayRef allIvs, - MatvecOp matvecOp, - OperationFolder *folder) { + MatvecOp matvecOp) { assert(allIvs.size() == 2); IndexHandle i(allIvs[0]), r_j(allIvs[1]); IndexedValueType A(matvecOp.getInput(0)), B(matvecOp.getInput(1)), @@ -169,8 +166,7 @@ template class LinalgScopedEmitter { public: static void emitScalarImplementation(ArrayRef allIvs, - MatmulOp matmulOp, - OperationFolder *folder) { + MatmulOp matmulOp) { assert(allIvs.size() == 3); IndexHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]); IndexedValueType A(matmulOp.getInput(0)), B(matmulOp.getInput(1)), @@ -183,17 +179,17 @@ public: template class LinalgScopedEmitter { public: - static void emitScalarImplementation(ArrayRef allIvs, ConvOp convOp, - OperationFolder *folder) { + static void emitScalarImplementation(ArrayRef allIvs, + ConvOp convOp) { auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); auto maps = loopToOperandRangesMaps(convOp); SmallVector fIdx( - foldedAffineApplies(b, loc, maps[0], allIvs, folder)); + makeCanonicalAffineApplies(b, loc, maps[0], allIvs)); SmallVector imIdx( - foldedAffineApplies(b, loc, maps[1], allIvs, folder)); + makeCanonicalAffineApplies(b, loc, maps[1], allIvs)); SmallVector oIdx( - foldedAffineApplies(b, loc, maps[2], allIvs, folder)); + makeCanonicalAffineApplies(b, loc, maps[2], allIvs)); IndexedValueType F(convOp.filter()), I(convOp.input()), O(convOp.output()); // Emit scalar form. O(oIdx) += F(fIdx) * I(imIdx); @@ -234,8 +230,7 @@ template class LinalgScopedEmitter { public: static void emitScalarImplementation(ArrayRef allIvs, - GenericOp genericOp, - OperationFolder *folder) { + GenericOp genericOp) { auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); using edsc::intrinsics::detail::ValueHandleArray; @@ -245,15 +240,15 @@ public: // 1.a. Emit std_load from input views. for (unsigned i = 0; i < nInputs; ++i) { - ValueHandleArray indexing(foldedAffineApplies( - b, loc, genericOp.getInputIndexingMap(i), allIvs, folder)); + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, genericOp.getInputIndexingMap(i), allIvs)); indexedValues[i] = std_load(genericOp.getInput(i), indexing); } // 1.b. Emit std_load from output views. for (unsigned i = 0; i < nOutputs; ++i) { - ValueHandleArray indexing(foldedAffineApplies( - b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder)); + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, genericOp.getOutputIndexingMap(i), allIvs)); indexedValues[nInputs + i] = std_load(genericOp.getOutput(i), indexing); } @@ -265,8 +260,8 @@ public: // 3. Emit std_store. for (unsigned i = 0; i < nOutputs; ++i) { - ValueHandleArray indexing(foldedAffineApplies( - b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder)); + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, genericOp.getOutputIndexingMap(i), allIvs)); std_store(callOp->getResult(i), genericOp.getOutput(i), indexing); } return; @@ -288,8 +283,8 @@ public: auto *yieldOp = cast(block.back()).getOperation(); assert(yieldOp->getNumOperands() == nOutputs); for (unsigned i = 0; i < nOutputs; ++i) { - ValueHandleArray indexing(foldedAffineApplies( - b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder)); + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, genericOp.getOutputIndexingMap(i), allIvs)); std_store(map.lookup(yieldOp->getOperand(i)), genericOp.getOutput(i), indexing); } @@ -330,8 +325,7 @@ template class LinalgScopedEmitter { public: static void emitScalarImplementation(ArrayRef allIvs, - IndexedGenericOp indexedGenericOp, - OperationFolder *folder) { + IndexedGenericOp indexedGenericOp) { auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); using edsc::intrinsics::detail::ValueHandleArray; @@ -346,16 +340,16 @@ public: // 1.a. Emit std_load from input views. for (unsigned i = 0; i < nInputs; ++i) { - ValueHandleArray indexing(foldedAffineApplies( - b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs, folder)); + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs)); indexedValues[nLoops + i] = std_load(indexedGenericOp.getInput(i), indexing); } // 1.b. Emit std_load from output views. for (unsigned i = 0; i < nOutputs; ++i) { - ValueHandleArray indexing(foldedAffineApplies( - b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs, folder)); + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); indexedValues[nLoops + nInputs + i] = std_load(indexedGenericOp.getOutput(i), indexing); } @@ -367,8 +361,8 @@ public: // 3. Emit std_store. for (unsigned i = 0; i < nOutputs; ++i) { - ValueHandleArray indexing(foldedAffineApplies( - b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs, folder)); + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); std_store(callOp->getResult(i), indexedGenericOp.getOutput(i), indexing); } @@ -391,96 +385,110 @@ public: auto *yieldOp = cast(block.back()).getOperation(); assert(yieldOp->getNumOperands() == nOutputs); for (unsigned i = 0; i < nOutputs; ++i) { - ValueHandleArray indexing(foldedAffineApplies( - b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs, folder)); + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); std_store(map.lookup(yieldOp->getOperand(i)), indexedGenericOp.getOutput(i), indexing); } } }; +namespace { +// This struct is for factoring out the implementation and support template +// instantiations in the following 2 cases: +// 1. Appending to a list of patterns via RewritePatternList. +// 2. Direct invocation via `linalgOpToLoops` and `linalgOpToAffineLoops`. +// The implementation must work both in DRR and inside a RewritePattern. As a +// consequence, (1) it is only allowed to emit new ops if the match is +// guaranteed to be a success, (2) it is not allowed erase/replace, and (3) an +// encompassing pattern must take care of the erasure logic. +template +class LinalgOpToLoopsImpl { +public: + static LogicalResult doit(Operation *op, PatternRewriter &rewriter); +}; +} // namespace + +template +LogicalResult LinalgOpToLoopsImpl::doit( + Operation *op, PatternRewriter &rewriter) { + OpBuilder b(op); + ScopedContext scope(b, op->getLoc()); + + // The flattened loopToOperandRangesMaps is expected to be an invertible + // permutation map (which is asserted in the inverse calculation). + auto linalgOp = cast(op); + auto invertedMap = + inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp))); + if (!invertedMap) { + LinalgScopedEmitter::emitScalarImplementation( + {}, linalgOp); + return success(); + } + + auto nPar = linalgOp.getNumParallelLoops(); + auto nRed = linalgOp.getNumReductionLoops(); + auto nWin = linalgOp.getNumWindowLoops(); + SmallVector allIvs(nPar + nRed + nWin); + SmallVector allPIvs = makeIndexHandlePointers(allIvs); + auto loopRanges = emitLoopRanges(scope.getBuilder(), scope.getLocation(), + invertedMap, getViewSizes(linalgOp)); + assert(loopRanges.size() == allIvs.size()); + + LoopNestRangeBuilder(allPIvs, loopRanges)([&] { + auto allIvValues = extractValues(allIvs); + LinalgScopedEmitter::emitScalarImplementation( + allIvValues, linalgOp); + }); + return success(); +} + template class LinalgRewritePattern : public RewritePattern { public: explicit LinalgRewritePattern(MLIRContext *context) - : RewritePattern(ConcreteOp::getOperationName(), /*benefit=*/1, context), - folder(context) {} + : RewritePattern(ConcreteOp::getOperationName(), 1, context) {} PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - OpBuilder b(op); - ScopedContext scope(b, op->getLoc()); - - // The flattened loopToOperandRangesMaps is expected to be an invertible - // permutation map (which is asserted in the inverse calculation). - auto linalgOp = cast(op); - auto invertedMap = - inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp))); - if (!invertedMap) { - LinalgScopedEmitter::emitScalarImplementation({}, linalgOp, - &folder); - rewriter.eraseOp(op); - return matchSuccess(); - } - - auto nPar = linalgOp.getNumParallelLoops(); - auto nRed = linalgOp.getNumReductionLoops(); - auto nWin = linalgOp.getNumWindowLoops(); - SmallVector allIvs(nPar + nRed + nWin); - SmallVector allPIvs = makeIndexHandlePointers(allIvs); - auto loopRanges = - emitLoopRanges(scope.getBuilder(), scope.getLocation(), invertedMap, - getViewSizes(linalgOp), &folder); - assert(loopRanges.size() == allIvs.size()); - - // clang-format off; - LoopNestRangeBuilder(allPIvs, loopRanges)([&] { - auto allIvValues = extractValues(allIvs); - LinalgScopedEmitter::emitScalarImplementation(allIvValues, - linalgOp, - &folder); - }); - // clang-format on + using Impl = LinalgOpToLoopsImpl; + if (failed(Impl::doit(op, rewriter))) + return matchFailure(); rewriter.eraseOp(op); return matchSuccess(); } - - mutable OperationFolder folder; }; // Helper classes for type list expansion. template -class ConversionList; +class RewritePatternList; template -class ConversionList { +class RewritePatternList { public: static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {} }; template -class ConversionList { +class RewritePatternList { public: static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) { patterns .insert>( ctx); - ConversionList::build(patterns, - ctx); + RewritePatternList::build( + patterns, ctx); } }; /// Populate the given list with patterns that convert from Linalg to LLVM. template -void ForOpRewritePatterns(OwningRewritePatternList &patterns, - MLIRContext *ctx) { - ConversionList::build(patterns, ctx); + >::build(patterns, ctx); } namespace { @@ -491,28 +499,114 @@ struct LowerLinalgToLoopsPass }; } // namespace +// Local folding pattern for AffineApplyOp that we can apply greedily. +// This replaces AffineApplyOp by the proper value in cases where the associated +// map is trivial. A trivial map here is defined as a map with a single result +// and either: +// 1. Zero operand + returns a single AffineConstantExpr +// 2. One operand + returns a single AffineDimExpr +// 3. One operands + returns a single AffineSymbolExpr +// +// In the first case, the AffineApplyOp is replaced by a new constant. In the +// other cases, it is replaced by its unique operand. +struct FoldAffineOp : public RewritePattern { + FoldAffineOp(MLIRContext *context) + : RewritePattern(AffineApplyOp::getOperationName(), 0, context) {} + + PatternMatchResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + AffineApplyOp affineApplyOp = cast(op); + auto map = affineApplyOp.getAffineMap(); + if (map.getNumResults() != 1 || map.getNumInputs() > 1) + return matchFailure(); + + AffineExpr expr = map.getResult(0); + if (map.getNumInputs() == 0) { + if (auto val = expr.dyn_cast()) { + rewriter.replaceOpWithNewOp(op, val.getValue()); + return matchSuccess(); + } + return matchFailure(); + } + if (expr.dyn_cast() || expr.dyn_cast()) { + rewriter.replaceOp(op, op->getOperand(0)); + return matchSuccess(); + } + return matchFailure(); + } +}; + template void LowerLinalgToLoopsPass::runOnFunction() { + auto *context = &this->getContext(); OwningRewritePatternList patterns; - ForOpRewritePatterns(patterns, - &this->getContext()); - - ConversionTarget target(this->getContext()); - target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalDialect(); - if (failed(applyPartialConversion(this->getFunction(), target, patterns))) { - this->signalPassFailure(); - } + // Canonicalization and folding patterns applied greedily allow cleaning up + // the emitted IR on the fly. + // TODO(ntv) fold view and subview ops? + FillRewritePatterns(patterns, context); + DimOp::getCanonicalizationPatterns(patterns, context); + AffineApplyOp::getCanonicalizationPatterns(patterns, context); + patterns.insert(context); + // Just apply the patterns greedily. + applyPatternsGreedily(this->getFunction(), patterns); } +/// Create a pass to convert Linalg operations to loop.for loops and +/// std.load/std.store accesses. std::unique_ptr> -mlir::linalg::createLowerLinalgToLoopsPass() { +mlir::linalg::createConvertLinalgToLoopsPass() { return std::make_unique< LowerLinalgToLoopsPass>(); } +/// Create a pass to convert Linalg operations to affine.for loops and +/// affine_load/affine_store accesses. +/// Placeholder for now, this is NYI. +std::unique_ptr> +mlir::linalg::createConvertLinalgToAffineLoopsPass() { + return std::make_unique< + LowerLinalgToLoopsPass>(); +} + +// Emits a loop nest of `loop.for` with the proper body for `op`. +template +LogicalResult mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter, + Operation *op) { + return LinalgOpToLoopsImpl::doit( + op, rewriter); +} + +// Emits a loop nest of `affine.for` with the proper body for `op`. +template +LogicalResult mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter, + Operation *op) { + return LinalgOpToLoopsImpl::doit( + op, rewriter); +} + +// TODO(ntv) Need to make these instantiations more future-proof to avoid the +// need to update as soon as we add new ops. +#define INSTANTIATE_LINALG_OP_TO_LOOPS(OP_TYPE) \ + template LogicalResult mlir::linalg::linalgOpToLoops( \ + PatternRewriter & rewriter, Operation * op); \ + template LogicalResult mlir::linalg::linalgOpToAffineLoops( \ + PatternRewriter & rewriter, Operation * op); + +INSTANTIATE_LINALG_OP_TO_LOOPS(CopyOp) +INSTANTIATE_LINALG_OP_TO_LOOPS(FillOp) +INSTANTIATE_LINALG_OP_TO_LOOPS(DotOp) +INSTANTIATE_LINALG_OP_TO_LOOPS(MatvecOp) +INSTANTIATE_LINALG_OP_TO_LOOPS(MatmulOp) +INSTANTIATE_LINALG_OP_TO_LOOPS(ConvOp) +INSTANTIATE_LINALG_OP_TO_LOOPS(GenericOp) +INSTANTIATE_LINALG_OP_TO_LOOPS(IndexedGenericOp) + static PassRegistration> structuredLoopsPass( - "linalg-lower-to-loops", + "convert-linalg-to-loops", "Lower the operations from the linalg dialect into loops"); + +static PassRegistration> + affineLoopsPass( + "convert-linalg-to-affine-loops", + "Lower the operations from the linalg dialect into affine loops"); diff --git a/third_party/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td b/third_party/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td index 839671c866a..97e0cb21704 100644 --- a/third_party/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td +++ b/third_party/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td @@ -73,4 +73,11 @@ def : Pattern<(DotOp:$op $a, $b, $c), [(TileLinalgOp<[8], "REG"> $op)], [(Constraint> $op)]>; +//===----------------------------------------------------------------------===// +// Linalg to loops patterns. +//===----------------------------------------------------------------------===// +def : Pattern<(DotOp:$op $a, $b, $c), + [(LinalgOpToLoops<"DotOp"> $op)], + [(Constraint> $op)]>; + #endif // TEST_LINALG_TRANSFORMS_PATTERNS From 31529dc7478f1b628ccabd3db7b7fb399f8cd206 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 27 Nov 2019 07:32:31 -0800 Subject: [PATCH 034/279] Automated g4 rollback of changelist 279340363. *** Reason for rollback *** This cl breaks Windows build with --noincompatible_remove_legacy_whole_archive, which is needed in order to fix https://github.com/tensorflow/addons/issues/663 PiperOrigin-RevId: 282765684 Change-Id: Ia557a864d5d5c5b8ca8964ec5985114f7464511c --- tensorflow/core/platform/BUILD | 1 - tensorflow/core/platform/status.cc | 3 --- 2 files changed, 4 deletions(-) diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD index 4a09e0dfd1e..8c676bc16e2 100644 --- a/tensorflow/core/platform/BUILD +++ b/tensorflow/core/platform/BUILD @@ -473,7 +473,6 @@ cc_library( ":logging", ":macros", ":mutex", - ":stacktrace", ":str_util", ":strcat", ":stringpiece", diff --git a/tensorflow/core/platform/status.cc b/tensorflow/core/platform/status.cc index a7fd3e693a1..d9cd02a27fb 100644 --- a/tensorflow/core/platform/status.cc +++ b/tensorflow/core/platform/status.cc @@ -22,7 +22,6 @@ limitations under the License. #include "absl/base/call_once.h" #include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/stacktrace.h" #include "tensorflow/core/platform/str_util.h" #include "tensorflow/core/platform/strcat.h" #include "tensorflow/core/platform/stringprintf.h" @@ -92,8 +91,6 @@ Status::Status(tensorflow::error::Code code, StringPiece msg) { state_ = std::unique_ptr(new State); state_->code = code; state_->msg = string(msg); - VLOG(5) << "Generated non-OK status: \"" << *this << "\". " - << CurrentStackTrace(); } void Status::Update(const Status& new_status) { From 6965b4d03d9107a86573a102f058ece3b48faac8 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Wed, 27 Nov 2019 07:41:00 -0800 Subject: [PATCH 035/279] Remove unused IntraProcessRendezvous::{mu_,status_}. The status was never written, so the mutex and Status objects are unnecessary. PiperOrigin-RevId: 282766872 Change-Id: I0148a65ebfbd66e89cd298a70dfe1e503481dcd7 --- tensorflow/core/common_runtime/rendezvous_mgr.cc | 5 ----- tensorflow/core/common_runtime/rendezvous_mgr.h | 5 ----- 2 files changed, 10 deletions(-) diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.cc b/tensorflow/core/common_runtime/rendezvous_mgr.cc index 4d296252f69..cd697574663 100644 --- a/tensorflow/core/common_runtime/rendezvous_mgr.cc +++ b/tensorflow/core/common_runtime/rendezvous_mgr.cc @@ -41,11 +41,6 @@ Status IntraProcessRendezvous::Send(const ParsedKey& parsed, const Rendezvous::Args& args, const Tensor& val, const bool is_dead) { VLOG(1) << "IntraProcessRendezvous Send " << this << " " << parsed.FullKey(); - { - mutex_lock l(mu_); - if (!status_.ok()) return status_; - } - // Buffers "val" and "device_context" in local_. return local_->Send(parsed, args, val, is_dead); } diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.h b/tensorflow/core/common_runtime/rendezvous_mgr.h index 1f7e6f28aeb..0602ac3e936 100644 --- a/tensorflow/core/common_runtime/rendezvous_mgr.h +++ b/tensorflow/core/common_runtime/rendezvous_mgr.h @@ -59,11 +59,6 @@ class IntraProcessRendezvous : public Rendezvous { const DeviceMgr* device_mgr_; Rendezvous* local_; // Owns a Ref on this object. - mutable mutex mu_; - - // Status given by StartAbort() if any. - Status status_ GUARDED_BY(mu_); - ~IntraProcessRendezvous() override; // Callback handling the case when a rendezvous has been From 35349536e77ad44061a7bdb340af623a2fea5f41 Mon Sep 17 00:00:00 2001 From: Deven Desai Date: Wed, 27 Nov 2019 17:07:10 +0000 Subject: [PATCH 036/279] [ROCm] Fix for the broken ROCm CSB - 191127 The following commit breaks the ROCm CSB https://github.com/tensorflow/tensorflow/commit/4b4bf4223d50cb28da44f081722520889b72583e It introduces new functionality (condition number method) and unit-tests to check the same. The unit-tests need support for the complex number types, which is currently not supported on the ROCm platform, and hence the consequent breakage. The "fix" is to skip testing the complex types when running the newly added unit-test on the ROCm platform --- tensorflow/python/ops/linalg/linear_operator_test_util.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tensorflow/python/ops/linalg/linear_operator_test_util.py b/tensorflow/python/ops/linalg/linear_operator_test_util.py index ae7b11778d4..33b8003ae2e 100644 --- a/tensorflow/python/ops/linalg/linear_operator_test_util.py +++ b/tensorflow/python/ops/linalg/linear_operator_test_util.py @@ -452,6 +452,11 @@ def _test_cond(use_placeholder, shapes_info, dtype): if 0 in shapes_info.shape[-2:]: return + # ROCm platform does not yet support complex types + if test.is_built_with_rocm() and \ + ((dtype == dtypes.complex64) or (dtype == dtypes.complex128)): + return + sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED # Ensure self-adjoint and PD so we get finite condition numbers. operator, mat = self.operator_and_matrix( From 0f2405bb659d2af0f39ee3c193bc264233e99450 Mon Sep 17 00:00:00 2001 From: Lucy Fox Date: Wed, 27 Nov 2019 09:32:27 -0800 Subject: [PATCH 037/279] Lower TF BitwiseOr to XLA HLO. PiperOrigin-RevId: 282784227 Change-Id: If82f013e4bbff3b8674a3691fbde2ba4b4ea5d7c --- .../mlir/tensorflow/ir/tf_generated_ops.td | 38 +++++++++++++++++++ .../compiler/mlir/xla/tests/legalize-tf.mlir | 21 ++++++++++ .../xla/transforms/legalize_tf_patterns.td | 7 ++-- 3 files changed, 63 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index d15cb91edca..b68634ba704 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -576,6 +576,44 @@ endian orderings will give different results. let hasCanonicalizer = 1; } +def TF_BitwiseOrOp : TF_Op<"BitwiseOr", [Broadcastable, Commutative, NoSideEffect]>, + WithBroadcastableBinOpBuilder { + let summary = "Elementwise computes the bitwise OR of `x` and `y`."; + + let description = [{ +The result will have those bits set, that are set in `x`, `y` or both. The +computation is performed on the underlying representations of `x` and `y`. + +For example: + +```python +import tensorflow as tf +from tensorflow.python.ops import bitwise_ops +dtype_list = [tf.int8, tf.int16, tf.int32, tf.int64, + tf.uint8, tf.uint16, tf.uint32, tf.uint64] + +for dtype in dtype_list: + lhs = tf.constant([0, 5, 3, 14], dtype=dtype) + rhs = tf.constant([5, 0, 7, 11], dtype=dtype) + exp = tf.constant([5, 5, 7, 15], dtype=tf.float32) + + res = bitwise_ops.bitwise_or(lhs, rhs) + tf.assert_equal(tf.cast(res, tf.float32), exp) # TRUE +``` + }]; + + let arguments = (ins + TF_IntTensor:$x, + TF_IntTensor:$y + ); + + let results = (outs + TF_IntTensor:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_BroadcastGradientArgsOp : TF_Op<"BroadcastGradientArgs", [NoSideEffect]> { let summary = [{ Return the reduction indices for computing gradients of s0 op s1 with broadcast. diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index 6c8737737ec..3004f2276fe 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -180,6 +180,27 @@ func @or_dynamic(%arg0: tensor, %arg1: tensor<1xi1>) -> tensor { return %0: tensor } +// CHECK-LABEL: func @bitwise_or +func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + // CHECK-NEXT: xla_hlo.or + %0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %0: tensor<4xi32> +} + +// CHECK-LABEL: func @bitwise_or_broadcast +func @bitwise_or_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { + // CHECK-NEXT: xla_hlo.or + %0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> + return %0: tensor<1x4xi8> +} + +// CHECK-LABEL: func @bitwise_or_dynamic +func @bitwise_or_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { + // CHECK-NEXT: xla_hlo.or + %0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor + return %0: tensor +} + // CHECK-LABEL: func @pow func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK-NEXT: xla_hlo.pow diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index 4bf7ee16d0a..ef11acab481 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -157,15 +157,16 @@ def : Pat<(TF_BroadcastToOp:$result AnyRankedTensor:$input, $shape), [(AnyRankedTensor $result)]>; //===----------------------------------------------------------------------===// -// Logical binary op patterns. +// Logical & bitwise binary op patterns. //===----------------------------------------------------------------------===// class DirectLogicalBinaryPat - : Pat<(FromOp I1Tensor:$l, I1Tensor:$r), + : Pat<(FromOp IntegerTensor:$l, IntegerTensor:$r), (ToOp $l, $r, (BinBroadcastDimensions $l, $r))>; foreach fromToBinPair = [[TF_LogicalAndOp, HLO_AndOp], - [TF_LogicalOrOp, HLO_OrOp]] in + [TF_LogicalOrOp, HLO_OrOp], + [TF_BitwiseOrOp, HLO_OrOp]] in def : DirectLogicalBinaryPat; //===----------------------------------------------------------------------===// From 8390a35f36601fd7c24917c24ed116627431216d Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Wed, 27 Nov 2019 10:05:14 -0800 Subject: [PATCH 038/279] Disable failing test PiperOrigin-RevId: 282789932 Change-Id: Ic0f2fef91a56e60b544f72498c352dc3af410037 --- tensorflow/lite/experimental/ios/BUILD.apple | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/lite/experimental/ios/BUILD.apple b/tensorflow/lite/experimental/ios/BUILD.apple index 011d1725e3e..2ccb207f19b 100644 --- a/tensorflow/lite/experimental/ios/BUILD.apple +++ b/tensorflow/lite/experimental/ios/BUILD.apple @@ -76,6 +76,10 @@ cc_library( # Used for building TensorFlowLiteC framework. build_test( name = "framework_build_test", + tags = [ + "nomsan", # b/145205324 + "notsan", # b/145205324 + ], targets = [ ":TensorFlowLiteC_framework", ":TensorFlowLiteCWithSelectTfOps_framework", From 84e14224d32bbd90b8592ac864c9447b18320212 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 27 Nov 2019 10:23:51 -0800 Subject: [PATCH 039/279] [XLA] Use LOG(QFATAL) rather than LOG(FATAL) in XLA_FLAGS argument parsing. PiperOrigin-RevId: 282793211 Change-Id: I2177e3bfc90c425a34e9c2812ec31f3b1f5a649f --- tensorflow/compiler/xla/parse_flags_from_env.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/parse_flags_from_env.cc b/tensorflow/compiler/xla/parse_flags_from_env.cc index e1e22f78417..3a914c694dc 100644 --- a/tensorflow/compiler/xla/parse_flags_from_env.cc +++ b/tensorflow/compiler/xla/parse_flags_from_env.cc @@ -17,9 +17,12 @@ limitations under the License. // modules to parse flags from an environtment variable, or a file named by the // environment variable. +#include "tensorflow/compiler/xla/parse_flags_from_env.h" + #include #include #include + #include #include #include @@ -28,7 +31,6 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" -#include "tensorflow/compiler/xla/parse_flags_from_env.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -218,9 +220,9 @@ bool ParseFlagsFromEnvAndDieIfUnknown( alternate_envvar); } - LOG(FATAL) << "Unknown flag" << (unknown_flags.size() > 1 ? "s" : "") - << " in " << envvar << ": " << absl::StrJoin(unknown_flags, " ") - << did_you_mean; + LOG(QFATAL) << "Unknown flag" << (unknown_flags.size() > 1 ? "s" : "") + << " in " << envvar << ": " << absl::StrJoin(unknown_flags, " ") + << did_you_mean; return false; } return result; From c2cf8890eccb23eaa0b1a41afd4d13d673b7a7d6 Mon Sep 17 00:00:00 2001 From: Srinivas Vasudevan Date: Wed, 27 Nov 2019 10:30:42 -0800 Subject: [PATCH 040/279] Add tf.linalg.experimental.conjugate_gradient to python API. PiperOrigin-RevId: 282794568 Change-Id: Ibcc15794d14d0bb07816d1851e08997c7cfc655d --- ...pi_def_CSRSparseMatrixToSparseTensor.pbtxt | 1 + .../api_def_DenseToCSRSparseMatrix.pbtxt | 1 + .../base_api/api_def_SparseMatrixAdd.pbtxt | 1 + .../base_api/api_def_SparseMatrixMatMul.pbtxt | 1 + .../base_api/api_def_SparseMatrixMul.pbtxt | 1 + .../base_api/api_def_SparseMatrixNNZ.pbtxt | 1 + .../api_def_SparseMatrixOrderingAMD.pbtxt | 1 + .../api_def_SparseMatrixSoftmax.pbtxt | 1 + .../api_def_SparseMatrixSoftmaxGrad.pbtxt | 1 + .../api_def_SparseMatrixSparseCholesky.pbtxt | 1 + .../api_def_SparseMatrixSparseMatMul.pbtxt | 1 + .../api_def_SparseMatrixTranspose.pbtxt | 1 + .../base_api/api_def_SparseMatrixZeros.pbtxt | 1 + ...pi_def_SparseTensorToCSRSparseMatrix.pbtxt | 1 + tensorflow/python/__init__.py | 1 + .../ops/linalg/sparse/conjugate_gradient.py | 2 +- tensorflow/python/ops/linalg/sparse/sparse.py | 6 +- .../tools/api/generator/api_init_files.bzl | 1 + .../tools/api/generator/api_init_files_v1.bzl | 1 + .../v1/tensorflow.linalg.experimental.pbtxt | 7 ++ .../api/golden/v1/tensorflow.linalg.pbtxt | 4 ++ .../api/golden/v1/tensorflow.raw_ops.pbtxt | 64 +++++++++++++++++++ .../v2/tensorflow.linalg.experimental.pbtxt | 7 ++ .../api/golden/v2/tensorflow.linalg.pbtxt | 4 ++ .../api/golden/v2/tensorflow.raw_ops.pbtxt | 64 +++++++++++++++++++ 25 files changed, 173 insertions(+), 2 deletions(-) create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.linalg.experimental.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.linalg.experimental.pbtxt diff --git a/tensorflow/core/api_def/base_api/api_def_CSRSparseMatrixToSparseTensor.pbtxt b/tensorflow/core/api_def/base_api/api_def_CSRSparseMatrixToSparseTensor.pbtxt index 2b932378339..21d1983b4b3 100644 --- a/tensorflow/core/api_def/base_api/api_def_CSRSparseMatrixToSparseTensor.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_CSRSparseMatrixToSparseTensor.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "CSRSparseMatrixToSparseTensor" + visibility: HIDDEN in_arg { name: "sparse_matrix" description: "A (possibly batched) CSRSparseMatrix." diff --git a/tensorflow/core/api_def/base_api/api_def_DenseToCSRSparseMatrix.pbtxt b/tensorflow/core/api_def/base_api/api_def_DenseToCSRSparseMatrix.pbtxt index 9e578c0f123..23822cbf438 100644 --- a/tensorflow/core/api_def/base_api/api_def_DenseToCSRSparseMatrix.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_DenseToCSRSparseMatrix.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "DenseToCSRSparseMatrix" + visibility: HIDDEN in_arg { name: "dense_input" description: "A Dense tensor." diff --git a/tensorflow/core/api_def/base_api/api_def_SparseMatrixAdd.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseMatrixAdd.pbtxt index 78c20141b67..58328c1941f 100644 --- a/tensorflow/core/api_def/base_api/api_def_SparseMatrixAdd.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SparseMatrixAdd.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SparseMatrixAdd" + visibility: HIDDEN in_arg { name: "a" description: "A CSRSparseMatrix." diff --git a/tensorflow/core/api_def/base_api/api_def_SparseMatrixMatMul.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseMatrixMatMul.pbtxt index 8d4da45cd8a..679edada8f8 100644 --- a/tensorflow/core/api_def/base_api/api_def_SparseMatrixMatMul.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SparseMatrixMatMul.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SparseMatrixMatMul" + visibility: HIDDEN in_arg { name: "a" description: "A CSRSparseMatrix." diff --git a/tensorflow/core/api_def/base_api/api_def_SparseMatrixMul.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseMatrixMul.pbtxt index 0f9a8b30351..aa7554f7104 100644 --- a/tensorflow/core/api_def/base_api/api_def_SparseMatrixMul.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SparseMatrixMul.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SparseMatrixMul" + visibility: HIDDEN in_arg { name: "a" description: "A CSRSparseMatrix." diff --git a/tensorflow/core/api_def/base_api/api_def_SparseMatrixNNZ.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseMatrixNNZ.pbtxt index 7e19822a6d7..cc04b94c82a 100644 --- a/tensorflow/core/api_def/base_api/api_def_SparseMatrixNNZ.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SparseMatrixNNZ.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SparseMatrixNNZ" + visibility: HIDDEN in_arg { name: "sparse_matrix" description: "A CSRSparseMatrix." diff --git a/tensorflow/core/api_def/base_api/api_def_SparseMatrixOrderingAMD.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseMatrixOrderingAMD.pbtxt index 32704f2cf33..9e7842c0f68 100644 --- a/tensorflow/core/api_def/base_api/api_def_SparseMatrixOrderingAMD.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SparseMatrixOrderingAMD.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SparseMatrixOrderingAMD" + visibility: HIDDEN in_arg { name: "input" description: "A `CSRSparseMatrix`." diff --git a/tensorflow/core/api_def/base_api/api_def_SparseMatrixSoftmax.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseMatrixSoftmax.pbtxt index bf868e5ff5c..31d9aaf44b0 100644 --- a/tensorflow/core/api_def/base_api/api_def_SparseMatrixSoftmax.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SparseMatrixSoftmax.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SparseMatrixSoftmax" + visibility: HIDDEN in_arg { name: "logits" description: "A CSRSparseMatrix." diff --git a/tensorflow/core/api_def/base_api/api_def_SparseMatrixSoftmaxGrad.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseMatrixSoftmaxGrad.pbtxt index bb7961b94fd..0705ec91132 100644 --- a/tensorflow/core/api_def/base_api/api_def_SparseMatrixSoftmaxGrad.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SparseMatrixSoftmaxGrad.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SparseMatrixSoftmaxGrad" + visibility: HIDDEN in_arg { name: "softmax" description: "A CSRSparseMatrix." diff --git a/tensorflow/core/api_def/base_api/api_def_SparseMatrixSparseCholesky.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseMatrixSparseCholesky.pbtxt index e69814e9f91..f7cdd3574ac 100644 --- a/tensorflow/core/api_def/base_api/api_def_SparseMatrixSparseCholesky.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SparseMatrixSparseCholesky.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SparseMatrixSparseCholesky" + visibility: HIDDEN in_arg { name: "input" description: "A `CSRSparseMatrix`." diff --git a/tensorflow/core/api_def/base_api/api_def_SparseMatrixSparseMatMul.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseMatrixSparseMatMul.pbtxt index 8c9cc0ba151..f84b3948be4 100644 --- a/tensorflow/core/api_def/base_api/api_def_SparseMatrixSparseMatMul.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SparseMatrixSparseMatMul.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SparseMatrixSparseMatMul" + visibility: HIDDEN in_arg { name: "a" description: "A CSRSparseMatrix." diff --git a/tensorflow/core/api_def/base_api/api_def_SparseMatrixTranspose.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseMatrixTranspose.pbtxt index 5a3cfba8cce..179bb312ade 100644 --- a/tensorflow/core/api_def/base_api/api_def_SparseMatrixTranspose.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SparseMatrixTranspose.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SparseMatrixTranspose" + visibility: HIDDEN in_arg { name: "input" description: "A CSRSparseMatrix." diff --git a/tensorflow/core/api_def/base_api/api_def_SparseMatrixZeros.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseMatrixZeros.pbtxt index c535bba6876..08a5cc16e7d 100644 --- a/tensorflow/core/api_def/base_api/api_def_SparseMatrixZeros.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SparseMatrixZeros.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SparseMatrixZeros" + visibility: HIDDEN in_arg { name: "dense_shape" description: "The desired matrix shape." diff --git a/tensorflow/core/api_def/base_api/api_def_SparseTensorToCSRSparseMatrix.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseTensorToCSRSparseMatrix.pbtxt index dc8c229056b..9deb28c61f5 100644 --- a/tensorflow/core/api_def/base_api/api_def_SparseTensorToCSRSparseMatrix.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SparseTensorToCSRSparseMatrix.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SparseTensorToCSRSparseMatrix" + visibility: HIDDEN in_arg { name: "indices" description: "SparseTensor indices." diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index 10034a9ed65..7ba4d4278fc 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -106,6 +106,7 @@ from tensorflow.python.ops import sets from tensorflow.python.ops import stateful_random_ops from tensorflow.python.ops.distributions import distributions from tensorflow.python.ops.linalg import linalg +from tensorflow.python.ops.linalg.sparse import sparse from tensorflow.python.ops.losses import losses from tensorflow.python.ops.signal import signal from tensorflow.python.profiler import profiler diff --git a/tensorflow/python/ops/linalg/sparse/conjugate_gradient.py b/tensorflow/python/ops/linalg/sparse/conjugate_gradient.py index dd8798141db..613309f856d 100644 --- a/tensorflow/python/ops/linalg/sparse/conjugate_gradient.py +++ b/tensorflow/python/ops/linalg/sparse/conjugate_gradient.py @@ -30,7 +30,7 @@ from tensorflow.python.ops.linalg import linalg_impl as linalg from tensorflow.python.util.tf_export import tf_export -@tf_export('linalg.experimental.sparse.conjugate_gradient') +@tf_export('linalg.experimental.conjugate_gradient') def conjugate_gradient(operator, rhs, preconditioner=None, diff --git a/tensorflow/python/ops/linalg/sparse/sparse.py b/tensorflow/python/ops/linalg/sparse/sparse.py index ef7abdc6b81..6f9b2522335 100644 --- a/tensorflow/python/ops/linalg/sparse/sparse.py +++ b/tensorflow/python/ops/linalg/sparse/sparse.py @@ -20,7 +20,11 @@ from __future__ import print_function # go/tf-wildcard-import # pylint: disable=wildcard-import -from tensorflow.python.ops.linalg.sparse.conjugate_gradient import * +from tensorflow.python.ops.linalg.sparse.conjugate_gradient import conjugate_gradient from tensorflow.python.ops.linalg.sparse.sparse_csr_matrix_grad import * from tensorflow.python.ops.linalg.sparse.sparse_csr_matrix_ops import * # pylint: enable=wildcard-import + +__all__ = [ + 'conjugate_gradient' +] diff --git a/tensorflow/python/tools/api/generator/api_init_files.bzl b/tensorflow/python/tools/api/generator/api_init_files.bzl index b2981b14209..45c1a959256 100644 --- a/tensorflow/python/tools/api/generator/api_init_files.bzl +++ b/tensorflow/python/tools/api/generator/api_init_files.bzl @@ -32,6 +32,7 @@ TENSORFLOW_API_INIT_FILES = [ "io/__init__.py", "queue/__init__.py", "linalg/__init__.py", + "linalg/experimental/__init__.py", "lite/__init__.py", "lite/experimental/__init__.py", "lite/experimental/microfrontend/__init__.py", diff --git a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl index 31e0c6ca457..a67afdcad29 100644 --- a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl +++ b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl @@ -36,6 +36,7 @@ TENSORFLOW_API_INIT_FILES_V1 = [ "layers/__init__.py", "layers/experimental/__init__.py", "linalg/__init__.py", + "linalg/experimental/__init__.py", "lite/__init__.py", "lite/constants/__init__.py", "lite/experimental/__init__.py", diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.experimental.pbtxt new file mode 100644 index 00000000000..9c9a67ba712 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.experimental.pbtxt @@ -0,0 +1,7 @@ +path: "tensorflow.linalg.experimental" +tf_module { + member_method { + name: "conjugate_gradient" + argspec: "args=[\'operator\', \'rhs\', \'preconditioner\', \'x\', \'tol\', \'max_iter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1e-05\', \'20\', \'conjugate_gradient\'], " + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt index f645db2f310..632400c6570 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt @@ -76,6 +76,10 @@ tf_module { name: "LinearOperatorZeros" mtype: "" } + member { + name: "experimental" + mtype: "" + } member_method { name: "adjoint" argspec: "args=[\'matrix\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index f09b683f0f9..604f676bf34 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -604,6 +604,18 @@ tf_module { name: "BytesProducedStatsDataset" argspec: "args=[\'input_dataset\', \'tag\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "CSRSparseMatrixComponents" + argspec: "args=[\'csr_sparse_matrix\', \'index\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "CSRSparseMatrixToDense" + argspec: "args=[\'sparse_input\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "CSRSparseMatrixToSparseTensor" + argspec: "args=[\'sparse_matrix\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "CSVDataset" argspec: "args=[\'filenames\', \'compression_type\', \'buffer_size\', \'header\', \'field_delim\', \'use_quote_delim\', \'na_value\', \'select_cols\', \'record_defaults\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -1032,6 +1044,10 @@ tf_module { name: "DeleteSessionTensor" argspec: "args=[\'handle\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "DenseToCSRSparseMatrix" + argspec: "args=[\'dense_input\', \'indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "DenseToDenseSetOperation" argspec: "args=[\'set1\', \'set2\', \'set_operation\', \'validate_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " @@ -3952,6 +3968,50 @@ tf_module { name: "SparseMatMul" argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'a_is_sparse\', \'b_is_sparse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'None\'], " } + member_method { + name: "SparseMatrixAdd" + argspec: "args=[\'a\', \'b\', \'alpha\', \'beta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "SparseMatrixMatMul" + argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'transpose_output\', \'conjugate_output\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\'], " + } + member_method { + name: "SparseMatrixMul" + argspec: "args=[\'a\', \'b\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "SparseMatrixNNZ" + argspec: "args=[\'sparse_matrix\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "SparseMatrixOrderingAMD" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "SparseMatrixSoftmax" + argspec: "args=[\'logits\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "SparseMatrixSoftmaxGrad" + argspec: "args=[\'softmax\', \'grad_softmax\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "SparseMatrixSparseCholesky" + argspec: "args=[\'input\', \'permutation\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "SparseMatrixSparseMatMul" + argspec: "args=[\'a\', \'b\', \'type\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'None\'], " + } + member_method { + name: "SparseMatrixTranspose" + argspec: "args=[\'input\', \'type\', \'conjugate\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "SparseMatrixZeros" + argspec: "args=[\'dense_shape\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "SparseReduceMax" argspec: "args=[\'input_indices\', \'input_values\', \'input_shape\', \'reduction_axes\', \'keep_dims\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " @@ -4048,6 +4108,10 @@ tf_module { name: "SparseTensorSliceDataset" argspec: "args=[\'indices\', \'values\', \'dense_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "SparseTensorToCSRSparseMatrix" + argspec: "args=[\'indices\', \'values\', \'dense_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "SparseToDense" argspec: "args=[\'sparse_indices\', \'output_shape\', \'sparse_values\', \'default_value\', \'validate_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.experimental.pbtxt new file mode 100644 index 00000000000..9c9a67ba712 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.experimental.pbtxt @@ -0,0 +1,7 @@ +path: "tensorflow.linalg.experimental" +tf_module { + member_method { + name: "conjugate_gradient" + argspec: "args=[\'operator\', \'rhs\', \'preconditioner\', \'x\', \'tol\', \'max_iter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1e-05\', \'20\', \'conjugate_gradient\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt index a58c988577a..041041f60ed 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt @@ -76,6 +76,10 @@ tf_module { name: "LinearOperatorZeros" mtype: "" } + member { + name: "experimental" + mtype: "" + } member_method { name: "adjoint" argspec: "args=[\'matrix\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index f09b683f0f9..604f676bf34 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -604,6 +604,18 @@ tf_module { name: "BytesProducedStatsDataset" argspec: "args=[\'input_dataset\', \'tag\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "CSRSparseMatrixComponents" + argspec: "args=[\'csr_sparse_matrix\', \'index\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "CSRSparseMatrixToDense" + argspec: "args=[\'sparse_input\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "CSRSparseMatrixToSparseTensor" + argspec: "args=[\'sparse_matrix\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "CSVDataset" argspec: "args=[\'filenames\', \'compression_type\', \'buffer_size\', \'header\', \'field_delim\', \'use_quote_delim\', \'na_value\', \'select_cols\', \'record_defaults\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -1032,6 +1044,10 @@ tf_module { name: "DeleteSessionTensor" argspec: "args=[\'handle\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "DenseToCSRSparseMatrix" + argspec: "args=[\'dense_input\', \'indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "DenseToDenseSetOperation" argspec: "args=[\'set1\', \'set2\', \'set_operation\', \'validate_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " @@ -3952,6 +3968,50 @@ tf_module { name: "SparseMatMul" argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'a_is_sparse\', \'b_is_sparse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'None\'], " } + member_method { + name: "SparseMatrixAdd" + argspec: "args=[\'a\', \'b\', \'alpha\', \'beta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "SparseMatrixMatMul" + argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'transpose_output\', \'conjugate_output\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\'], " + } + member_method { + name: "SparseMatrixMul" + argspec: "args=[\'a\', \'b\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "SparseMatrixNNZ" + argspec: "args=[\'sparse_matrix\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "SparseMatrixOrderingAMD" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "SparseMatrixSoftmax" + argspec: "args=[\'logits\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "SparseMatrixSoftmaxGrad" + argspec: "args=[\'softmax\', \'grad_softmax\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "SparseMatrixSparseCholesky" + argspec: "args=[\'input\', \'permutation\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "SparseMatrixSparseMatMul" + argspec: "args=[\'a\', \'b\', \'type\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'None\'], " + } + member_method { + name: "SparseMatrixTranspose" + argspec: "args=[\'input\', \'type\', \'conjugate\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "SparseMatrixZeros" + argspec: "args=[\'dense_shape\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "SparseReduceMax" argspec: "args=[\'input_indices\', \'input_values\', \'input_shape\', \'reduction_axes\', \'keep_dims\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " @@ -4048,6 +4108,10 @@ tf_module { name: "SparseTensorSliceDataset" argspec: "args=[\'indices\', \'values\', \'dense_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "SparseTensorToCSRSparseMatrix" + argspec: "args=[\'indices\', \'values\', \'dense_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "SparseToDense" argspec: "args=[\'sparse_indices\', \'output_shape\', \'sparse_values\', \'default_value\', \'validate_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " From b98db48be1b8238e213d1d6b9ab607ea96088224 Mon Sep 17 00:00:00 2001 From: Lukas Geiger Date: Wed, 27 Nov 2019 18:35:45 +0000 Subject: [PATCH 041/279] Fix create_python_api_test.py --- .../tools/api/generator/create_python_api_test.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/tools/api/generator/create_python_api_test.py b/tensorflow/python/tools/api/generator/create_python_api_test.py index 010f189dcb2..76404d6c82b 100644 --- a/tensorflow/python/tools/api/generator/create_python_api_test.py +++ b/tensorflow/python/tools/api/generator/create_python_api_test.py @@ -62,7 +62,7 @@ class CreatePythonApiTest(test.TestCase): del sys.modules[_MODULE_NAME] def testFunctionImportIsAdded(self): - imports, _ = create_python_api.get_api_init_text( + imports, _, _ = create_python_api.get_api_init_text( packages=[create_python_api._DEFAULT_PACKAGE], output_package='tensorflow', api_name='tensorflow', @@ -97,7 +97,7 @@ class CreatePythonApiTest(test.TestCase): msg='compat.v1 in %s' % str(imports.keys())) def testClassImportIsAdded(self): - imports, _ = create_python_api.get_api_init_text( + imports, _, _ = create_python_api.get_api_init_text( packages=[create_python_api._DEFAULT_PACKAGE], output_package='tensorflow', api_name='tensorflow', @@ -116,7 +116,7 @@ class CreatePythonApiTest(test.TestCase): msg='%s not in %s' % (expected_import, str(imports))) def testConstantIsAdded(self): - imports, _ = create_python_api.get_api_init_text( + imports, _, _ = create_python_api.get_api_init_text( packages=[create_python_api._DEFAULT_PACKAGE], output_package='tensorflow', api_name='tensorflow', @@ -132,7 +132,7 @@ class CreatePythonApiTest(test.TestCase): msg='%s not in %s' % (expected, str(imports))) def testCompatModuleIsAdded(self): - imports, _ = create_python_api.get_api_init_text( + imports, _, _ = create_python_api.get_api_init_text( packages=[create_python_api._DEFAULT_PACKAGE], output_package='tensorflow', api_name='tensorflow', @@ -144,7 +144,7 @@ class CreatePythonApiTest(test.TestCase): msg='compat.v1.test not in %s' % str(imports.keys())) def testNestedCompatModulesAreAdded(self): - imports, _ = create_python_api.get_api_init_text( + imports, _, _ = create_python_api.get_api_init_text( packages=[create_python_api._DEFAULT_PACKAGE], output_package='tensorflow', api_name='tensorflow', From 39029c26b932efdc503ba762b0616d5942400784 Mon Sep 17 00:00:00 2001 From: Amit Patankar Date: Wed, 27 Nov 2019 10:35:12 -0800 Subject: [PATCH 042/279] Fix the MacOS convert_test. PiperOrigin-RevId: 282795396 Change-Id: I4207d1200e6326767a1879b627017916942a811e --- tensorflow/python/framework/dtypes.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py index a9a8ac0518a..6bcf71915c7 100644 --- a/tensorflow/python/framework/dtypes.py +++ b/tensorflow/python/framework/dtypes.py @@ -21,8 +21,11 @@ import numpy as np from six.moves import builtins from tensorflow.core.framework import types_pb2 -from tensorflow.python import _dtypes +# pywrap_tensorflow must be imported prior to _dtypes for the MacOS linker +# to resolve the protobufs properly. +# pylint: disable=unused-import,g-bad-import-order from tensorflow.python import pywrap_tensorflow +from tensorflow.python import _dtypes from tensorflow.python.util.tf_export import tf_export _np_bfloat16 = pywrap_tensorflow.TF_bfloat16_type() From eab06047007be01dc7285499a7d4c1bff88af3a7 Mon Sep 17 00:00:00 2001 From: Robert David Date: Wed, 27 Nov 2019 10:35:48 -0800 Subject: [PATCH 043/279] Faster NeonIsZeroVector. The integer version now uses 2 ALU and 1 branch branch instruction and no extra registers to test if a NEON register is all zeros. Previously it used 6 ALU and 4 branch instructions and 1 extra NEON register. The float version uses 3 ALU and 1 branch instructions, and one extra NEON register instead of 6 ALU and 1 branch instructions and no extra NEON registers. PiperOrigin-RevId: 282795511 Change-Id: I5e8fc32302a13ca6eb05f8a4b7396433d9dc3acc --- .../internal/optimized/neon_tensor_utils.cc | 49 +++++++++++++------ 1 file changed, 35 insertions(+), 14 deletions(-) diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc index b23e0305990..d5c1f227b9a 100644 --- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -1970,6 +1970,37 @@ void NeonSub1Vector(const int16_t* vector, int v_size, int16_t* result) { } } +namespace { + +#if __aarch64__ +inline bool IsAllZero(const uint32x4_t u32x4) { + const uint32_t u32 = vmaxvq_u32(u32x4); + return !u32; +} +#else +inline bool IsAllZero(const uint32x4_t u32x4) { + const uint32x2_t u32x2 = vqadd_u32(vget_high_u32(u32x4), vget_low_u32(u32x4)); + const uint64x1_t u64 = vreinterpret_u64_u32(u32x2); + return !vget_lane_u64(u64, 0); +} +#endif + +#ifndef __SSE__ +// With Intel NEON-2-SSE translator library, this is a redefinition.. +inline bool IsAllZero(const int8x16_t v) { + return IsAllZero(vreinterpretq_u32_s8(v)); +} +#endif + +inline bool IsAllZero(const float32x4_t v_f32x4) { + const float32x4_t zero_f32x4 = vmovq_n_f32(0.0f); + // Compare-absolute greater-than, |v| > |0|, equivalently v != 0 + const uint32x4_t cmp_result = vcagtq_f32(v_f32x4, zero_f32x4); + return IsAllZero(cmp_result); +} + +} // namespace + bool NeonIsZeroVector(const float* vector, int v_size) { // If v_size is not divisible by the vector size, then we need to process the // final few elements sequentially. postamble_start shows the start index @@ -1977,15 +2008,10 @@ bool NeonIsZeroVector(const float* vector, int v_size) { const int postamble_start = RoundDownVectors(v_size); - const float32x4_t zero_x4_float = vmovq_n_f32(0.0f); int v = 0; for (; v < postamble_start; v += kFloatValuesPerNeonVector) { - const float32x4_t i_x4_float = vld1q_f32(vector + v); - uint32x4_t cmp_result = vceqq_f32(i_x4_float, zero_x4_float); - if (vgetq_lane_u32(cmp_result, 0) == 0) return false; - if (vgetq_lane_u32(cmp_result, 1) == 0) return false; - if (vgetq_lane_u32(cmp_result, 2) == 0) return false; - if (vgetq_lane_u32(cmp_result, 3) == 0) return false; + const float32x4_t v_f32x4 = vld1q_f32(vector + v); + if (!IsAllZero(v_f32x4)) return false; } // Postamble loop for (; v < v_size; ++v) { @@ -2001,15 +2027,10 @@ bool NeonIsZeroVector(const int8_t* vector, int v_size) { const int postamble_start = RoundDownVectors(v_size); - static const int32x4_t zero_x4_int32 = vmovq_n_s32(0); int v = 0; for (; v < postamble_start; v += kInt8ValuesPerNeonVector) { - const int32x4_t i_x4_int32 = vreinterpretq_s32_s8(vld1q_s8(vector + v)); - const uint32x4_t cmp_result = vceqq_s32(i_x4_int32, zero_x4_int32); - if (vgetq_lane_u32(cmp_result, 0) == 0) return false; - if (vgetq_lane_u32(cmp_result, 1) == 0) return false; - if (vgetq_lane_u32(cmp_result, 2) == 0) return false; - if (vgetq_lane_u32(cmp_result, 3) == 0) return false; + const int8x16_t v_s8x16 = vld1q_s8(vector + v); + if (!IsAllZero(v_s8x16)) return false; } // Postamble loop for (; v < v_size; ++v) { From 09bc59f39536b7f56c904f044f1583cfefd00669 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 27 Nov 2019 10:37:34 -0800 Subject: [PATCH 044/279] Add CPU benchmark for Conv2D backprop input PiperOrigin-RevId: 282795817 Change-Id: I0fca45e0bf0f62e4cb87b5e17bab020d45a5dfd5 --- tensorflow/core/kernels/conv_grad_input_ops_benchmark_test.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/core/kernels/conv_grad_input_ops_benchmark_test.cc b/tensorflow/core/kernels/conv_grad_input_ops_benchmark_test.cc index ee4d2800ca7..938ef976ed8 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops_benchmark_test.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops_benchmark_test.cc @@ -158,4 +158,6 @@ BENCHMARK_DTYPE(NCHW, 64, fp16); #endif // GOOGLE_CUDA +BM_Conv2DBwdInputFmt(float, NHWC, 8, 32, 32, 128, 1, 1, 128, cpu); + } // namespace tensorflow From fb28210ab07463cd7a2a7428d5b4be28fc55a1ab Mon Sep 17 00:00:00 2001 From: Advait Jain Date: Wed, 27 Nov 2019 10:49:05 -0800 Subject: [PATCH 045/279] Replace gemmlowp::ScopedProfilingLabel with portable wrapper. PiperOrigin-RevId: 282797995 Change-Id: I09c10f84223607fb906a9e4464841b65b50c07ae --- tensorflow/lite/kernels/internal/BUILD | 2 +- .../internal/reference/reference_ops.h | 81 +++++++++---------- 2 files changed, 40 insertions(+), 43 deletions(-) diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index c547b7921dc..d8bb8b41fff 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -459,7 +459,6 @@ cc_library( ":types", ":scoped_profiling_label_wrapper", "@gemmlowp//:fixedpoint", - "@gemmlowp//:profiler", "//third_party/eigen3", "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:op_macros", @@ -512,6 +511,7 @@ cc_library( ":compatibility", ":quantization_util", ":round", + ":scoped_profiling_label_wrapper", ":strided_slice_logic", ":legacy_types", ":tensor", diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h index d91c00f755e..502598c27f5 100644 --- a/tensorflow/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h @@ -28,7 +28,6 @@ limitations under the License. #include "third_party/eigen3/Eigen/Core" #include "fixedpoint/fixedpoint.h" -#include "profiling/instrumentation.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" @@ -54,6 +53,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/reference/softmax.h" #include "tensorflow/lite/kernels/internal/reference/strided_slice.h" #include "tensorflow/lite/kernels/internal/round.h" +#include "tensorflow/lite/kernels/internal/scoped_profiling_label_wrapper.h" #include "tensorflow/lite/kernels/internal/strided_slice_logic.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/internal/types.h" @@ -191,7 +191,7 @@ inline void Relu(const RuntimeShape& input_shape, const T* input_data, template inline void Relu1(const RuntimeShape& input_shape, const T* input_data, const RuntimeShape& output_shape, T* output_data) { - gemmlowp::ScopedProfilingLabel label("Relu1 (not fused)"); + ScopedProfilingLabelWrapper label("Relu1 (not fused)"); const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; ++i) { const T val = input_data[i]; @@ -204,7 +204,7 @@ inline void Relu1(const RuntimeShape& input_shape, const T* input_data, inline void Relu6(const RuntimeShape& input_shape, const float* input_data, const RuntimeShape& output_shape, float* output_data) { - gemmlowp::ScopedProfilingLabel label("Relu6 (not fused)"); + ScopedProfilingLabelWrapper label("Relu6 (not fused)"); const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; ++i) { const float val = input_data[i]; @@ -219,7 +219,7 @@ template inline void ReluX(const tflite::ReluParams& params, const RuntimeShape& input_shape, const T* input_data, const RuntimeShape& output_shape, T* output_data) { - gemmlowp::ScopedProfilingLabel label("Quantized ReluX (not fused)"); + ScopedProfilingLabelWrapper label("Quantized ReluX (not fused)"); const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; ++i) { const int32 val = static_cast(input_data[i]); @@ -237,7 +237,7 @@ template inline void ReluX(const tflite::ActivationParams& params, const RuntimeShape& input_shape, const T* input_data, const RuntimeShape& output_shape, T* output_data) { - gemmlowp::ScopedProfilingLabel label("Quantized ReluX (not fused)"); + ScopedProfilingLabelWrapper label("Quantized ReluX (not fused)"); const int flat_size = MatchingFlatSize(input_shape, output_shape); const T max_value = params.quantized_activation_max; const T min_value = params.quantized_activation_min; @@ -252,7 +252,7 @@ inline void ReluX(const tflite::ActivationParams& params, inline void LeakyRelu(const tflite::LeakyReluParams& params, const RuntimeShape& input_shape, const float* input_data, const RuntimeShape& output_shape, float* output_data) { - gemmlowp::ScopedProfilingLabel label("LeakyRelu (not fused)"); + ScopedProfilingLabelWrapper label("LeakyRelu (not fused)"); const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; ++i) { const float val = input_data[i]; @@ -267,7 +267,7 @@ inline void QuantizeLeakyRelu(const LeakyReluParams& params, T q_alpha, const T* input_data, const RuntimeShape& output_shape, T* output_data) { - gemmlowp::ScopedProfilingLabel label("LeakyRelu (not fused)"); + ScopedProfilingLabelWrapper label("LeakyRelu (not fused)"); const int flat_size = MatchingFlatSize(input_shape, output_shape); static const int32 quantized_min = std::numeric_limits::min(); static const int32 quantized_max = std::numeric_limits::max(); @@ -420,12 +420,11 @@ inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params, } } - inline void Mul(const ArithmeticParams& params, const RuntimeShape& input1_shape, const int16* input1_data, const RuntimeShape& input2_shape, const int16* input2_data, const RuntimeShape& output_shape, int16* output_data) { - gemmlowp::ScopedProfilingLabel label("Mul/Int16"); + ScopedProfilingLabelWrapper label("Mul/Int16"); const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape); @@ -444,7 +443,7 @@ inline void Mul(const ArithmeticParams& params, const RuntimeShape& input1_shape, const int16* input1_data, const RuntimeShape& input2_shape, const int16* input2_data, const RuntimeShape& output_shape, uint8* output_data) { - gemmlowp::ScopedProfilingLabel label("Mul/Int16Uint8"); + ScopedProfilingLabelWrapper label("Mul/Int16Uint8"); int32 output_offset = params.output_offset; int32 output_activation_min = params.quantized_activation_min; int32 output_activation_max = params.quantized_activation_max; @@ -581,7 +580,7 @@ inline void Div(const ArithmeticParams& params, const RuntimeShape& output_shape, uint8* output_data) { TFLITE_DCHECK_LE(params.quantized_activation_min, params.quantized_activation_max); - gemmlowp::ScopedProfilingLabel label("Div/8bit"); + ScopedProfilingLabelWrapper label("Div/8bit"); const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape); @@ -695,7 +694,7 @@ inline void BroadcastSub4DSlow(const ArithmeticParams& params, const float* input2_data, const RuntimeShape& output_shape, float* output_data) { - gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/float"); + ScopedProfilingLabelWrapper label("BroadcastSub4DSlow/float"); NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, @@ -736,7 +735,7 @@ inline void BroadcastSub4DSlow(const ArithmeticParams& params, const uint8* input2_data, const RuntimeShape& output_shape, uint8* output_data) { - gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/uint8"); + ScopedProfilingLabelWrapper label("BroadcastSub4DSlow/uint8"); NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, @@ -800,7 +799,7 @@ inline void BroadcastSub4DSlow(const ArithmeticParams& params, const int32* input2_data, const RuntimeShape& output_shape, int32* output_data) { - gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/int32"); + ScopedProfilingLabelWrapper label("BroadcastSub4DSlow/int32"); NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, @@ -840,7 +839,7 @@ void BroadcastSub4DSlow(const ArithmeticParams& params, const RuntimeShape& input1_shape, const T* input1_data, const RuntimeShape& input2_shape, const T* input2_data, const RuntimeShape& output_shape, T* output_data) { - gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/templated"); + ScopedProfilingLabelWrapper label("BroadcastSub4DSlow/templated"); NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, @@ -918,7 +917,7 @@ inline void SubWithActivation(const ArithmeticParams& params, const int32* input2_data, const RuntimeShape& output_shape, int32* output_data) { - gemmlowp::ScopedProfilingLabel label("SubWithActivation"); + ScopedProfilingLabelWrapper label("SubWithActivation"); const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape); for (int i = 0; i < flat_size; ++i) { @@ -948,7 +947,7 @@ inline void Sub16(const ArithmeticParams& params, const RuntimeShape& input1_shape, const int16_t* input1_data, const RuntimeShape& input2_shape, const int16_t* input2_data, const RuntimeShape& output_shape, int16_t* output_data) { - gemmlowp::ScopedProfilingLabel label("Sub/Int16"); + ScopedProfilingLabelWrapper label("Sub/Int16"); const int input1_shift = params.input1_shift; const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape); @@ -996,7 +995,7 @@ template void Pack(const PackParams& params, const RuntimeShape* const* input_shapes, const Scalar* const* input_data, const RuntimeShape& output_shape, Scalar* output_data) { - gemmlowp::ScopedProfilingLabel label("Pack"); + ScopedProfilingLabelWrapper label("Pack"); const int dimensions = output_shape.DimensionsCount(); int axis = params.axis; int inputs_count = params.inputs_count; @@ -1024,7 +1023,7 @@ template void Unpack(const UnpackParams& params, const RuntimeShape& input_shape, const Scalar* input_data, const RuntimeShape& output_shape, Scalar* const* output_datas) { - gemmlowp::ScopedProfilingLabel label("Unpack"); + ScopedProfilingLabelWrapper label("Unpack"); const int dimensions = input_shape.DimensionsCount(); const int outputs_count = params.num_split; @@ -1058,7 +1057,7 @@ void PackWithScaling(const PackParams& params, const RuntimeShape* const* input_shapes, const uint8* const* input_data, const RuntimeShape& output_shape, uint8* output_data) { - gemmlowp::ScopedProfilingLabel label("PackWithScaling"); + ScopedProfilingLabelWrapper label("PackWithScaling"); const int dimensions = output_shape.DimensionsCount(); int axis = params.axis; const int32* input_zeropoint = params.input_zeropoint; @@ -1108,7 +1107,7 @@ void DepthConcatenation(const ConcatenationParams& params, const RuntimeShape* const* input_shapes, const Scalar* const* input_data, const RuntimeShape& output_shape, Scalar* output_data) { - gemmlowp::ScopedProfilingLabel label("DepthConcatenation"); + ScopedProfilingLabelWrapper label("DepthConcatenation"); auto params_copy = params; params_copy.axis = 3; Concatenation(params_copy, input_shapes, input_data, output_shape, @@ -1512,7 +1511,7 @@ template void Split(const SplitParams& params, const RuntimeShape& input_shape, const Scalar* input_data, const RuntimeShape* const* output_shapes, Scalar* const* output_data) { - gemmlowp::ScopedProfilingLabel label("Split"); + ScopedProfilingLabelWrapper label("Split"); const int split_dimensions = input_shape.DimensionsCount(); int axis = params.axis < 0 ? params.axis + split_dimensions : params.axis; int outputs_count = params.num_split; @@ -1616,7 +1615,7 @@ inline void LogSoftmax(const SoftmaxParams& params, inline void LogSoftmax(const SoftmaxParams& params, const RuntimeShape& input_shape, const uint8* input_data, const RuntimeShape& output_shape, uint8* output_data) { - gemmlowp::ScopedProfilingLabel label("LogSoftmax/8bit"); + ScopedProfilingLabelWrapper label("LogSoftmax/8bit"); const int32 input_multiplier = params.input_multiplier; const int32 input_left_shift = params.input_left_shift; const int32 reverse_scaling_divisor = params.reverse_scaling_divisor; @@ -1770,7 +1769,7 @@ inline void Requantize(const input_type* input_data, int32_t size, int32_t effective_scale_multiplier, int32_t effective_scale_shift, int32_t input_zeropoint, int32_t output_zeropoint, output_type* output_data) { - gemmlowp::ScopedProfilingLabel label("Requantize"); + ScopedProfilingLabelWrapper label("Requantize"); const bool same_scale = (effective_scale_multiplier == 1 << 30 && effective_scale_shift == 1); if (same_scale) { @@ -1807,7 +1806,7 @@ inline void Requantize(const input_type* input_data, int32_t size, inline void FakeQuant(const tflite::FakeQuantParams& op_params, const RuntimeShape& input_shape, const float* input_data, const RuntimeShape& output_shape, float* output_data) { - gemmlowp::ScopedProfilingLabel label("FakeQuant"); + ScopedProfilingLabelWrapper label("FakeQuant"); float rmin = op_params.minmax.min; float rmax = op_params.minmax.max; int num_bits = op_params.num_bits; @@ -1860,7 +1859,7 @@ inline void Gather(const tflite::GatherParams& op_params, const RuntimeShape& input_shape, const T* input_data, const RuntimeShape& coords_shape, const CoordsT* coords_data, const RuntimeShape& output_shape, T* output_data) { - gemmlowp::ScopedProfilingLabel label("Gather"); + ScopedProfilingLabelWrapper label("Gather"); int axis = op_params.axis; if (axis < 0) { axis += input_shape.DimensionsCount(); @@ -1898,7 +1897,7 @@ inline void GatherNd(const RuntimeShape& params_shape, const RuntimeShape& indices_shape, const IndicesT* indices_data, const RuntimeShape& output_shape, ParamsT* output_data) { - gemmlowp::ScopedProfilingLabel label("GatherNd"); + ScopedProfilingLabelWrapper label("GatherNd"); int n_slices = 1; int slice_size = 1; @@ -1935,7 +1934,7 @@ inline void ScatterNd(const RuntimeShape& indices_shape, const RuntimeShape& updates_shape, const UpdatesT* updates_data, const RuntimeShape& output_shape, UpdatesT* output_data) { - gemmlowp::ScopedProfilingLabel label("ScatterNd"); + ScopedProfilingLabelWrapper label("ScatterNd"); int n_slices = 1; int slice_size = 1; @@ -2043,7 +2042,7 @@ inline void SpaceToBatchND( const RuntimeShape& unextended_input2_shape, const int32* block_shape_data, const RuntimeShape& unextended_input3_shape, const int32* paddings_data, const RuntimeShape& unextended_output_shape, T* output_data) { - gemmlowp::ScopedProfilingLabel label("SpaceToBatchND"); + ScopedProfilingLabelWrapper label("SpaceToBatchND"); TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); const RuntimeShape input1_shape = @@ -2101,7 +2100,7 @@ inline void BatchToSpaceND( const RuntimeShape& unextended_input2_shape, const int32* block_shape_data, const RuntimeShape& unextended_input3_shape, const int32* crops_data, const RuntimeShape& unextended_output_shape, T* output_data) { - gemmlowp::ScopedProfilingLabel label("BatchToSpaceND"); + ScopedProfilingLabelWrapper label("BatchToSpaceND"); TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); const RuntimeShape input1_shape = @@ -2351,7 +2350,7 @@ inline void Slice(const tflite::SliceParams& op_params, template inline void Exp(const T* input_data, const size_t num_elements, T* output_data) { - gemmlowp::ScopedProfilingLabel label("Exp"); + ScopedProfilingLabelWrapper label("Exp"); for (size_t idx = 0; idx < num_elements; ++idx) { output_data[idx] = std::exp(input_data[idx]); } @@ -2482,7 +2481,7 @@ inline bool Mean(const T* input_data, const int* input_dims, const int* output_dims, const int output_num_dims, const int* axis, const int num_axis_dimensions, bool keep_dims, int* temp_index, int* resolved_axis, U* temp_sum) { - gemmlowp::ScopedProfilingLabel label("Mean"); + ScopedProfilingLabelWrapper label("Mean"); // Reset output data. size_t num_outputs = 1; for (int idx = 0; idx < output_num_dims; ++idx) { @@ -2536,7 +2535,7 @@ inline void Mean(const tflite::MeanParams& op_params, const RuntimeShape& unextended_input_shape, const T* input_data, const RuntimeShape& unextended_output_shape, T* output_data) { - gemmlowp::ScopedProfilingLabel label("Mean4D"); + ScopedProfilingLabelWrapper label("Mean4D"); // Current implementation only supports dimension equals 4 and simultaneous // reduction over width and height. @@ -2581,7 +2580,7 @@ inline void Mean(const tflite::MeanParams& op_params, float input_scale, const RuntimeShape& unextended_output_shape, uint8_t* output_data, int32 output_zero_point, float output_scale) { - gemmlowp::ScopedProfilingLabel label("Mean4D/Uint8"); + ScopedProfilingLabelWrapper label("Mean4D/Uint8"); // Current implementation only supports dimension equals 4 and simultaneous // reduction over width and height. @@ -2647,11 +2646,9 @@ inline bool QuantizedMeanOrSum(const T* input_data, int32 input_zero_point, bool compute_sum) { const bool uint8_case = std::is_same::value; if (uint8_case) { - gemmlowp::ScopedProfilingLabel label(compute_sum ? "Sum/Uint8" - : "Mean/Uint8"); + ScopedProfilingLabelWrapper label(compute_sum ? "Sum/Uint8" : "Mean/Uint8"); } else { - gemmlowp::ScopedProfilingLabel label(compute_sum ? "Sum/Int8" - : "Mean/Int8"); + ScopedProfilingLabelWrapper label(compute_sum ? "Sum/Int8" : "Mean/Int8"); } // Reset output data. size_t num_outputs = 1; @@ -3248,7 +3245,7 @@ template void Reverse(int axis, const RuntimeShape& input_shape, const Scalar* input_data, const RuntimeShape& output_shape, Scalar* output_data) { - gemmlowp::ScopedProfilingLabel label("Reverse"); + ScopedProfilingLabelWrapper label("Reverse"); int outer_size = 1; for (int i = 0; i < axis; ++i) { @@ -3276,7 +3273,7 @@ void ReverseSequence(const TS* seq_lengths, const int seq_dim, const int batch_dim, const RuntimeShape& input_shape, const Scalar* input_data, const RuntimeShape& output_shape, Scalar* output_data) { - gemmlowp::ScopedProfilingLabel label("ReverseSequence"); + ScopedProfilingLabelWrapper label("ReverseSequence"); int outer_size = 1; int outer_dim = std::min(batch_dim, seq_dim); @@ -3353,7 +3350,7 @@ void ReverseSequence(const TS* seq_lengths, const int seq_dim, template inline void HardSwish(const RuntimeShape& input_shape, const T* input_data, const RuntimeShape& output_shape, T* output_data) { - gemmlowp::ScopedProfilingLabel label("ReferenceHardSwish/Float"); + ScopedProfilingLabelWrapper label("ReferenceHardSwish/Float"); auto matching_size = MatchingFlatSize(input_shape, output_shape); const T* in_end = input_data + matching_size; for (; input_data < in_end; input_data++, output_data++) { @@ -3387,7 +3384,7 @@ template inline void HardSwish(const HardSwishParams& params, const RuntimeShape& input_shape, const T* input_data, const RuntimeShape& output_shape, T* output_data) { - gemmlowp::ScopedProfilingLabel label("ReferenceHardSwish/Quantized"); + ScopedProfilingLabelWrapper label("ReferenceHardSwish/Quantized"); const int flat_size = MatchingFlatSize(input_shape, output_shape); From b1e14472c8899e65274f9a43a1a682e2ec28ad74 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Wed, 27 Nov 2019 10:55:59 -0800 Subject: [PATCH 046/279] NFC: Add headers for MLIR builders and standard types PiperOrigin-RevId: 282799238 Change-Id: Ifb73b4953dc85ba79670348f34afc6fb36c33bfb --- tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc | 2 ++ tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc | 1 + 2 files changed, 3 insertions(+) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc index da3d26c1b72..20483691a92 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc @@ -26,11 +26,13 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Support/SMLoc.h" #include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:local_config_mlir #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir #include "mlir/IR/OpDefinition.h" // TF:local_config_mlir #include "mlir/IR/OpImplementation.h" // TF:local_config_mlir #include "mlir/IR/OperationSupport.h" // TF:local_config_mlir #include "mlir/IR/PatternMatch.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir #include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir #include "mlir/IR/Types.h" // TF:local_config_mlir #include "mlir/IR/Value.h" // TF:local_config_mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc index ad49c8970cf..c672d624944 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" #include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:local_config_mlir #include "mlir/IR/Function.h" // TF:local_config_mlir #include "mlir/IR/Identifier.h" // TF:local_config_mlir #include "mlir/IR/Module.h" // TF:local_config_mlir From e2b52c093541d45ab22e82587c5a0083caf84e6b Mon Sep 17 00:00:00 2001 From: Gunhan Gulsoy Date: Wed, 27 Nov 2019 10:58:43 -0800 Subject: [PATCH 047/279] Remove dependency on lib/core:status within tensorflow/core/platform PiperOrigin-RevId: 282799700 Change-Id: I6bbff913e7f07189a92328faffeaa7197925eafa --- tensorflow/core/platform/BUILD | 4 ++-- tensorflow/core/platform/cloud/auth_provider.h | 3 ++- .../platform/cloud/compute_engine_metadata_client.h | 2 +- tensorflow/core/platform/cloud/curl_http_request.h | 3 ++- tensorflow/core/platform/cloud/file_block_cache.h | 3 ++- tensorflow/core/platform/cloud/gcs_file_system.h | 2 +- tensorflow/core/platform/cloud/http_request.h | 3 ++- tensorflow/core/platform/cloud/http_request_fake.h | 3 ++- tensorflow/core/platform/cloud/oauth_client.h | 3 ++- .../core/platform/cloud/ram_file_block_cache.h | 3 ++- .../core/platform/cloud/retrying_file_system.h | 2 +- tensorflow/core/platform/cloud/retrying_utils.h | 3 ++- tensorflow/core/platform/cloud/time_util.h | 2 +- tensorflow/core/platform/cloud/zone_provider.h | 3 ++- tensorflow/core/platform/default/build_refactor.bzl | 12 ++++++------ .../core/platform/default/human_readable_json.cc | 2 +- .../core/platform/default/posix_file_system.cc | 2 +- tensorflow/core/platform/env.h | 2 +- tensorflow/core/platform/error.cc | 2 +- tensorflow/core/platform/error.h | 2 +- tensorflow/core/platform/file_system_helper.cc | 2 +- tensorflow/core/platform/file_system_helper.h | 2 +- .../core/platform/hadoop/hadoop_file_system.cc | 2 +- tensorflow/core/platform/human_readable_json.h | 2 +- tensorflow/core/platform/load_library.h | 2 +- 25 files changed, 40 insertions(+), 31 deletions(-) diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD index 8c676bc16e2..1c85e9d0769 100644 --- a/tensorflow/core/platform/BUILD +++ b/tensorflow/core/platform/BUILD @@ -173,9 +173,9 @@ cc_library( hdrs = ["error.h"], deps = [ ":platform", + ":status", + ":strcat", ":types", - "//tensorflow/core/lib/core:status", - "//tensorflow/core/platform:strcat", ], ) diff --git a/tensorflow/core/platform/cloud/auth_provider.h b/tensorflow/core/platform/cloud/auth_provider.h index 7347bc626d8..4c219b70221 100644 --- a/tensorflow/core/platform/cloud/auth_provider.h +++ b/tensorflow/core/platform/cloud/auth_provider.h @@ -17,8 +17,9 @@ limitations under the License. #define TENSORFLOW_CORE_PLATFORM_CLOUD_AUTH_PROVIDER_H_ #include + #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/compute_engine_metadata_client.h b/tensorflow/core/platform/cloud/compute_engine_metadata_client.h index 7f060327da5..d7611615606 100644 --- a/tensorflow/core/platform/cloud/compute_engine_metadata_client.h +++ b/tensorflow/core/platform/cloud/compute_engine_metadata_client.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_COMPUTE_ENGINE_METADATA_CLIENT_H_ #define TENSORFLOW_CORE_PLATFORM_CLOUD_COMPUTE_ENGINE_METADATA_CLIENT_H_ -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/cloud/http_request.h" #include "tensorflow/core/platform/cloud/retrying_utils.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/curl_http_request.h b/tensorflow/core/platform/cloud/curl_http_request.h index 9ad75e52f20..ddb1599e871 100644 --- a/tensorflow/core/platform/cloud/curl_http_request.h +++ b/tensorflow/core/platform/cloud/curl_http_request.h @@ -19,14 +19,15 @@ limitations under the License. #include #include #include + #include #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/cloud/http_request.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/file_block_cache.h b/tensorflow/core/platform/cloud/file_block_cache.h index c98b10640fa..3e66a9937a6 100644 --- a/tensorflow/core/platform/cloud/file_block_cache.h +++ b/tensorflow/core/platform/cloud/file_block_cache.h @@ -22,11 +22,12 @@ limitations under the License. #include #include #include -#include "tensorflow/core/lib/core/status.h" + #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/notification.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h index 0bd95f7c6b6..a4d3bcc8f05 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.h +++ b/tensorflow/core/platform/cloud/gcs_file_system.h @@ -21,7 +21,6 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/cloud/auth_provider.h" #include "tensorflow/core/platform/cloud/compute_engine_metadata_client.h" #include "tensorflow/core/platform/cloud/compute_engine_zone_provider.h" @@ -32,6 +31,7 @@ limitations under the License. #include "tensorflow/core/platform/cloud/http_request.h" #include "tensorflow/core/platform/cloud/retrying_file_system.h" #include "tensorflow/core/platform/file_system.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/http_request.h b/tensorflow/core/platform/cloud/http_request.h index e925eefb1f2..91825b5958a 100644 --- a/tensorflow/core/platform/cloud/http_request.h +++ b/tensorflow/core/platform/cloud/http_request.h @@ -19,12 +19,13 @@ limitations under the License. #include #include #include + #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/http_request_fake.h b/tensorflow/core/platform/cloud/http_request_fake.h index 0a1164b64a7..f1bed661715 100644 --- a/tensorflow/core/platform/cloud/http_request_fake.h +++ b/tensorflow/core/platform/cloud/http_request_fake.h @@ -20,14 +20,15 @@ limitations under the License. #include #include #include + #include #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/cloud/curl_http_request.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/core/platform/cloud/oauth_client.h b/tensorflow/core/platform/cloud/oauth_client.h index 519d69acf98..ed8bf257253 100644 --- a/tensorflow/core/platform/cloud/oauth_client.h +++ b/tensorflow/core/platform/cloud/oauth_client.h @@ -17,10 +17,11 @@ limitations under the License. #define TENSORFLOW_CORE_PLATFORM_CLOUD_OAUTH_CLIENT_H_ #include + #include "include/json/json.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/cloud/http_request.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/ram_file_block_cache.h b/tensorflow/core/platform/cloud/ram_file_block_cache.h index 46fb9a35b88..d418a0fb6b0 100644 --- a/tensorflow/core/platform/cloud/ram_file_block_cache.h +++ b/tensorflow/core/platform/cloud/ram_file_block_cache.h @@ -22,12 +22,13 @@ limitations under the License. #include #include #include -#include "tensorflow/core/lib/core/status.h" + #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/cloud/file_block_cache.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/notification.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/core/platform/cloud/retrying_file_system.h b/tensorflow/core/platform/cloud/retrying_file_system.h index 9659edd890e..5e85447fd3d 100644 --- a/tensorflow/core/platform/cloud/retrying_file_system.h +++ b/tensorflow/core/platform/cloud/retrying_file_system.h @@ -21,11 +21,11 @@ limitations under the License. #include #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/platform/cloud/retrying_utils.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/file_system.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/retrying_utils.h b/tensorflow/core/platform/cloud/retrying_utils.h index 1a7ce1b122b..70b98463477 100644 --- a/tensorflow/core/platform/cloud/retrying_utils.h +++ b/tensorflow/core/platform/cloud/retrying_utils.h @@ -17,7 +17,8 @@ limitations under the License. #define TENSORFLOW_CORE_PLATFORM_CLOUD_RETRYING_UTILS_H_ #include -#include "tensorflow/core/lib/core/status.h" + +#include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/time_util.h b/tensorflow/core/platform/cloud/time_util.h index d6d4bc499fe..944efe9bbd4 100644 --- a/tensorflow/core/platform/cloud/time_util.h +++ b/tensorflow/core/platform/cloud/time_util.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_TIME_UTIL_H_ #define TENSORFLOW_CORE_PLATFORM_CLOUD_TIME_UTIL_H_ -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/zone_provider.h b/tensorflow/core/platform/cloud/zone_provider.h index 421b6a7e1af..6f809ceb381 100644 --- a/tensorflow/core/platform/cloud/zone_provider.h +++ b/tensorflow/core/platform/cloud/zone_provider.h @@ -17,8 +17,9 @@ limitations under the License. #define TENSORFLOW_CORE_PLATFORM_CLOUD_ZONE_PROVIDER_H_ #include + #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/core/platform/default/build_refactor.bzl b/tensorflow/core/platform/default/build_refactor.bzl index e7eddeb3343..eb12e6b5585 100644 --- a/tensorflow/core/platform/default/build_refactor.bzl +++ b/tensorflow/core/platform/default/build_refactor.bzl @@ -80,7 +80,6 @@ TF_DEFAULT_PLATFORM_LIBRARIES = { "//tensorflow/core/lib/core:blocking_counter", "//tensorflow/core/lib/core:error_codes_proto_cc", "//tensorflow/core/lib/core:errors", - "//tensorflow/core/lib/core:status", "//tensorflow/core/lib/core:stringpiece", "//tensorflow/core/lib/io:path", "//tensorflow/core/platform", @@ -97,6 +96,7 @@ TF_DEFAULT_PLATFORM_LIBRARIES = { "//tensorflow/core/platform:platform_port", "//tensorflow/core/platform:protobuf", "//tensorflow/core/platform:setround", + "//tensorflow/core/platform:status", "//tensorflow/core/platform:stringpiece", "//tensorflow/core/platform:stringprintf", "//tensorflow/core/platform:strcat", @@ -132,9 +132,9 @@ TF_DEFAULT_PLATFORM_LIBRARIES = { ], "deps": [ "//tensorflow/core/lib/core:errors", - "//tensorflow/core/lib/core:status", - "//tensorflow/core/platform:strcat", "//tensorflow/core/platform:protobuf", + "//tensorflow/core/platform:status", + "//tensorflow/core/platform:strcat", ], "visibility": ["//visibility:private"], "tags": ["no_oss", "manual"], @@ -149,7 +149,7 @@ TF_DEFAULT_PLATFORM_LIBRARIES = { ], "deps": [ "//tensorflow/core/lib/core:errors", - "//tensorflow/core/lib/core:status", + "//tensorflow/core/platform:status", ], "visibility": ["//visibility:private"], "tags": ["no_oss", "manual"], @@ -405,7 +405,6 @@ TF_WINDOWS_PLATFORM_LIBRARIES = { "//tensorflow/core/lib/core:blocking_counter", "//tensorflow/core/lib/core:error_codes_proto_cc", "//tensorflow/core/lib/core:errors", - "//tensorflow/core/lib/core:status", "//tensorflow/core/lib/core:stringpiece", "//tensorflow/core/lib/io:path", "//tensorflow/core/platform", @@ -422,6 +421,7 @@ TF_WINDOWS_PLATFORM_LIBRARIES = { "//tensorflow/core/platform:platform_port", "//tensorflow/core/platform:protobuf", "//tensorflow/core/platform:setround", + "//tensorflow/core/platform:status", "//tensorflow/core/platform:stringpiece", "//tensorflow/core/platform:stringprintf", "//tensorflow/core/platform:strcat", @@ -458,7 +458,7 @@ TF_WINDOWS_PLATFORM_LIBRARIES = { ], "deps": [ "//tensorflow/core/lib/core:errors", - "//tensorflow/core/lib/core:status", + "//tensorflow/core/platform:status", "//tensorflow/core/platform:windows_wide_char_impl", ], "visibility": ["//visibility:private"], diff --git a/tensorflow/core/platform/default/human_readable_json.cc b/tensorflow/core/platform/default/human_readable_json.cc index c3a61a3d58c..2ecbf437800 100644 --- a/tensorflow/core/platform/default/human_readable_json.cc +++ b/tensorflow/core/platform/default/human_readable_json.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/core/platform/human_readable_json.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/strcat.h" namespace tensorflow { diff --git a/tensorflow/core/platform/default/posix_file_system.cc b/tensorflow/core/platform/default/posix_file_system.cc index 56c00279e6b..106a0412fb7 100644 --- a/tensorflow/core/platform/default/posix_file_system.cc +++ b/tensorflow/core/platform/default/posix_file_system.cc @@ -28,12 +28,12 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/default/posix_file_system.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/error.h" #include "tensorflow/core/platform/file_system_helper.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/strcat.h" #include "tensorflow/core/protobuf/error_codes.pb.h" diff --git a/tensorflow/core/platform/env.h b/tensorflow/core/platform/env.h index be8399c879b..cd6a6488e52 100644 --- a/tensorflow/core/platform/env.h +++ b/tensorflow/core/platform/env.h @@ -24,7 +24,6 @@ limitations under the License. #include #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/env_time.h" #include "tensorflow/core/platform/file_system.h" #include "tensorflow/core/platform/macros.h" @@ -32,6 +31,7 @@ limitations under the License. #include "tensorflow/core/platform/numa.h" #include "tensorflow/core/platform/platform.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/core/platform/error.cc b/tensorflow/core/platform/error.cc index 00ddf1dc241..cb09a3a86cc 100644 --- a/tensorflow/core/platform/error.cc +++ b/tensorflow/core/platform/error.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/strcat.h" namespace tensorflow { diff --git a/tensorflow/core/platform/error.h b/tensorflow/core/platform/error.h index 3ba3e749c34..0b08ac36682 100644 --- a/tensorflow/core/platform/error.h +++ b/tensorflow/core/platform/error.h @@ -18,8 +18,8 @@ limitations under the License. #include -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/platform.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/core/platform/file_system_helper.cc b/tensorflow/core/platform/file_system_helper.cc index e47fa133c6d..c909b36688e 100644 --- a/tensorflow/core/platform/file_system_helper.cc +++ b/tensorflow/core/platform/file_system_helper.cc @@ -19,11 +19,11 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/file_system.h" #include "tensorflow/core/platform/platform.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/str_util.h" #include "tensorflow/core/platform/threadpool.h" diff --git a/tensorflow/core/platform/file_system_helper.h b/tensorflow/core/platform/file_system_helper.h index 8d812b0e381..7427dea77ef 100644 --- a/tensorflow/core/platform/file_system_helper.h +++ b/tensorflow/core/platform/file_system_helper.h @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.cc b/tensorflow/core/platform/hadoop/hadoop_file_system.cc index 59c3fe2540f..60668ff4f61 100644 --- a/tensorflow/core/platform/hadoop/hadoop_file_system.cc +++ b/tensorflow/core/platform/hadoop/hadoop_file_system.cc @@ -17,7 +17,6 @@ limitations under the License. #include -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" @@ -26,6 +25,7 @@ limitations under the License. #include "tensorflow/core/platform/file_system_helper.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" #include "third_party/hadoop/hdfs.h" namespace tensorflow { diff --git a/tensorflow/core/platform/human_readable_json.h b/tensorflow/core/platform/human_readable_json.h index 49908eac7c8..f6830e20207 100644 --- a/tensorflow/core/platform/human_readable_json.h +++ b/tensorflow/core/platform/human_readable_json.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PLATFORM_HUMAN_READABLE_JSON_H_ #define TENSORFLOW_CORE_PLATFORM_HUMAN_READABLE_JSON_H_ -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/core/platform/load_library.h b/tensorflow/core/platform/load_library.h index c7eeb2918ca..01efd4c1d01 100644 --- a/tensorflow/core/platform/load_library.h +++ b/tensorflow/core/platform/load_library.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PLATFORM_LOAD_LIBRARY_H_ #define TENSORFLOW_CORE_PLATFORM_LOAD_LIBRARY_H_ -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { From b159db043c57de83789c155abce9a022762502ee Mon Sep 17 00:00:00 2001 From: Daniel Situnayake Date: Wed, 27 Nov 2019 11:02:45 -0800 Subject: [PATCH 048/279] Add supported devices to TensorFlow Lite for Microcontrollers documentation PiperOrigin-RevId: 282800486 Change-Id: I2f85ab5639f9828768d6461a4c33c910b783c6c2 --- tensorflow/lite/g3doc/microcontrollers/index.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/lite/g3doc/microcontrollers/index.md b/tensorflow/lite/g3doc/microcontrollers/index.md index b78b131784e..2ead371d4b4 100644 --- a/tensorflow/lite/g3doc/microcontrollers/index.md +++ b/tensorflow/lite/g3doc/microcontrollers/index.md @@ -35,6 +35,8 @@ There are example applications available for the following development boards: * [Arduino Nano 33 BLE Sense](https://store.arduino.cc/usa/nano-33-ble-sense-with-headers) * [SparkFun Edge](https://www.sparkfun.com/products/15170) * [STM32F746 Discovery kit](https://www.st.com/en/evaluation-tools/32f746gdiscovery.html) +* [Adafruit EdgeBadge](https://www.adafruit.com/product/4400) +* [Adafruit TensorFlow Lite for Microcontrollers Kit] To learn more about the libraries and examples, see [Get started with microcontrollers](get_started.md). From d65366e77639769165bff7277c0b97260d4a6d1a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 27 Nov 2019 11:02:54 -0800 Subject: [PATCH 049/279] Create a //third_party/tensorflow/c:tf_status_headers target and use it as a dependency in targets under //third_party/tensorflow/python, to avoid duplicating symbols that also come from pywrap_tensorflow_internal. PiperOrigin-RevId: 282800548 Change-Id: I253c3bbd9937f12cf4c12557b4397775403fc45b --- tensorflow/c/BUILD | 6 ++++++ tensorflow/python/BUILD | 4 ++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index fd58c1b173e..cabc3b21e45 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -196,6 +196,12 @@ cc_library( }), ) +cc_library( + name = "tf_status_headers", + hdrs = ["tf_status.h"], + visibility = ["//visibility:public"], +) + cc_library( name = "tf_file_statistics", hdrs = ["tf_file_statistics.h"], diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 2f0241146b8..3d4f22583a2 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -421,7 +421,7 @@ cc_library( srcs = ["lib/core/py_exception_registry.cc"], hdrs = ["lib/core/py_exception_registry.h"], deps = [ - "//tensorflow/c:tf_status", + "//tensorflow/c:tf_status_headers", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//third_party/python_runtime:headers", @@ -456,7 +456,7 @@ cc_library( features = ["-parse_headers"], deps = [ ":py_exception_registry", - "//tensorflow/c:tf_status", + "//tensorflow/c:tf_status_headers", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//third_party/python_runtime:headers", From 8053a43598b7d4e0e5b436d7fa2b803c7b6c9f7f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 27 Nov 2019 11:04:01 -0800 Subject: [PATCH 050/279] Explicitly export files needed by other packages PiperOrigin-RevId: 282800860 Change-Id: Ie7b8863629c0e4b2169c654714bf4a2338d667e5 --- tensorflow/core/kernels/BUILD | 70 ++++++++++++++++++- .../python/data/experimental/benchmarks/BUILD | 5 ++ 2 files changed, 74 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index f66102ab3ac..0872c0b0611 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -8210,4 +8210,72 @@ tf_cc_shared_object( ], ) -exports_files(["ops_testutil.h"]) +exports_files([ + "cwise_op_abs.cc", + "cwise_op_add_1.cc", + "cwise_op_add_2.cc", + "cwise_op_atan2.cc", + "cwise_op_cos.cc", + "cwise_op_div.cc", + "cwise_op_equal_to_1.cc", + "cwise_op_equal_to_2.cc", + "cwise_op_exp.cc", + "cwise_op_floor.cc", + "cwise_op_floor_div.cc", + "cwise_op_floor_mod.cc", + "cwise_op_gpu_add.cu.cc", + "cwise_op_gpu_atan2.cu.cc", + "cwise_op_gpu_cos.cu.cc", + "cwise_op_gpu_div.cu.cc", + "cwise_op_gpu_equal_to.cu.cc", + "cwise_op_gpu_exp.cu.cc", + "cwise_op_gpu_floor.cu.cc", + "cwise_op_gpu_floor_div.cu.cc", + "cwise_op_gpu_greater.cu.cc", + "cwise_op_gpu_greater_equal.cu.cc", + "cwise_op_gpu_less.cu.cc", + "cwise_op_gpu_less_equal.cu.cc", + "cwise_op_gpu_logical_not.cu.cc", + "cwise_op_gpu_maximum.cu.cc", + "cwise_op_gpu_minimum.cu.cc", + "cwise_op_gpu_mul.cu.cc", + "cwise_op_gpu_neg.cu.cc", + "cwise_op_gpu_round.cu.cc", + "cwise_op_gpu_rsqrt.cu.cc", + "cwise_op_gpu_select.cu.cc", + "cwise_op_gpu_sigmoid.cu.cc", + "cwise_op_gpu_sin.cu.cc", + "cwise_op_gpu_sqrt.cu.cc", + "cwise_op_gpu_squared_difference.cu.cc", + "cwise_op_gpu_sub.cu.cc", + "cwise_op_gpu_tanh.cu.cc", + "cwise_op_greater.cc", + "cwise_op_greater_equal.cc", + "cwise_op_less.cc", + "cwise_op_less_equal.cc", + "cwise_op_logical_not.cc", + "cwise_op_maximum.cc", + "cwise_op_minimum.cc", + "cwise_op_mul_1.cc", + "cwise_op_mul_2.cc", + "cwise_op_neg.cc", + "cwise_op_not_equal_to_2.cc", + "cwise_op_round.cc", + "cwise_op_rsqrt.cc", + "cwise_op_select.cc", + "cwise_op_sigmoid.cc", + "cwise_op_sin.cc", + "cwise_op_sqrt.cc", + "cwise_op_square.cc", + "cwise_op_squared_difference.cc", + "cwise_op_sub.cc", + "cwise_op_tanh.cc", + "dequantize_op.cc", + "ops_testutil.h", + "quantize_and_dequantize_op.cc", + "quantize_op.cc", + "sparse_cross_op.cc", + "sparse_fill_empty_rows_op.cc", + "sparse_reshape_op.cc", + "unary_ops_composition.cc", +]) diff --git a/tensorflow/python/data/experimental/benchmarks/BUILD b/tensorflow/python/data/experimental/benchmarks/BUILD index 0540fa069d3..0759a345b29 100644 --- a/tensorflow/python/data/experimental/benchmarks/BUILD +++ b/tensorflow/python/data/experimental/benchmarks/BUILD @@ -7,6 +7,11 @@ package( exports_files(["LICENSE"]) +exports_files( + ["autotune_benchmark.py"], + visibility = ["//tensorflow:internal"], +) + tf_py_test( name = "autotune_benchmark", srcs = ["autotune_benchmark.py"], From 235f0e2a89c0a2c866d6dc24932dd5c281a503f2 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Wed, 27 Nov 2019 11:13:16 -0800 Subject: [PATCH 051/279] Expose the definition of LocalRendezvous to its users. Previously, creating a LocalRendezvous required a dynamic allocation and refcount manipulation, and invoking it required making virtual method calls. This change allows users (especially IntraProcessRendezvous) to instantiate a LocalRendezvous directly as a member, which avoids these overheads. PiperOrigin-RevId: 282802675 Change-Id: I81d91a2ddd2365bb5612c402db0e1bf66a9ea5f4 --- .../core/common_runtime/rendezvous_mgr.cc | 26 +- .../core/common_runtime/rendezvous_mgr.h | 10 +- tensorflow/core/framework/local_rendezvous.cc | 300 ++++++++++++++++++ tensorflow/core/framework/local_rendezvous.h | 75 +++++ tensorflow/core/framework/rendezvous.cc | 289 +---------------- 5 files changed, 402 insertions(+), 298 deletions(-) create mode 100644 tensorflow/core/framework/local_rendezvous.cc create mode 100644 tensorflow/core/framework/local_rendezvous.h diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.cc b/tensorflow/core/common_runtime/rendezvous_mgr.cc index cd697574663..0d5e79667db 100644 --- a/tensorflow/core/common_runtime/rendezvous_mgr.cc +++ b/tensorflow/core/common_runtime/rendezvous_mgr.cc @@ -33,16 +33,16 @@ limitations under the License. namespace tensorflow { IntraProcessRendezvous::IntraProcessRendezvous(const DeviceMgr* device_mgr) - : device_mgr_(device_mgr), local_(NewLocalRendezvous()) {} + : device_mgr_(device_mgr) {} -IntraProcessRendezvous::~IntraProcessRendezvous() { local_->Unref(); } +IntraProcessRendezvous::~IntraProcessRendezvous() {} -Status IntraProcessRendezvous::Send(const ParsedKey& parsed, +Status IntraProcessRendezvous::Send(const ParsedKey& key, const Rendezvous::Args& args, const Tensor& val, const bool is_dead) { - VLOG(1) << "IntraProcessRendezvous Send " << this << " " << parsed.FullKey(); + VLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey(); // Buffers "val" and "device_context" in local_. - return local_->Send(parsed, args, val, is_dead); + return local_.Send(key, args, val, is_dead); } void IntraProcessRendezvous::SameWorkerRecvDone( @@ -116,16 +116,16 @@ void IntraProcessRendezvous::SameWorkerRecvDone( out, 0 /*dev_to_dev_stream_index*/, std::move(done), sync_dst_compute); } -void IntraProcessRendezvous::RecvAsync(const ParsedKey& parsed, - const Rendezvous::Args& recv_args, +void IntraProcessRendezvous::RecvAsync(const ParsedKey& key, + const Rendezvous::Args& args, DoneCallback done) { - VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << parsed.FullKey(); + VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << key.FullKey(); MEMDEBUG_CACHE_OP("RecvAsync"); // Recv the tensor from local_. - local_->RecvAsync( - parsed, recv_args, - [this, parsed, done = std::move(done)]( + local_.RecvAsync( + key, args, + [this, key, done = std::move(done)]( const Status& status, const Rendezvous::Args& send_args, const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) mutable { @@ -141,7 +141,7 @@ void IntraProcessRendezvous::RecvAsync(const ParsedKey& parsed, }; if (status.ok() && in.IsInitialized()) { - SameWorkerRecvDone(parsed, send_args, recv_args, in, out, + SameWorkerRecvDone(key, send_args, recv_args, in, out, std::move(final_callback)); } else { final_callback(status); @@ -151,7 +151,7 @@ void IntraProcessRendezvous::RecvAsync(const ParsedKey& parsed, void IntraProcessRendezvous::StartAbort(const Status& s) { CHECK(!s.ok()); - local_->StartAbort(s); + local_.StartAbort(s); } } // end namespace tensorflow diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.h b/tensorflow/core/common_runtime/rendezvous_mgr.h index 0602ac3e936..a9d3de122f0 100644 --- a/tensorflow/core/common_runtime/rendezvous_mgr.h +++ b/tensorflow/core/common_runtime/rendezvous_mgr.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/framework/local_rendezvous.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" @@ -31,12 +32,11 @@ namespace tensorflow { // IntraProcessRendezvous is a Rendezvous which expects all producers // and consumers to be devices immediately accessible within the -// process. That is, it will never be necessary to perform an RPC to +// process. That is, it will never be necessary to perform an RPC to // communicate with either. // -// Buffering of Tensor values is delegated to a "local" Rendezvous -// obtained from NewLocalRendezvous(). This class just adds -// functionality to coordinate multiple process-local devices. +// Buffering of Tensor values is delegated to a `LocalRendezvous`. This class +// just adds functionality to coordinate multiple process-local devices. class IntraProcessRendezvous : public Rendezvous { public: explicit IntraProcessRendezvous(const DeviceMgr* device_mgr); @@ -57,7 +57,7 @@ class IntraProcessRendezvous : public Rendezvous { private: const DeviceMgr* device_mgr_; - Rendezvous* local_; // Owns a Ref on this object. + LocalRendezvous local_; ~IntraProcessRendezvous() override; diff --git a/tensorflow/core/framework/local_rendezvous.cc b/tensorflow/core/framework/local_rendezvous.cc new file mode 100644 index 00000000000..c21974552e7 --- /dev/null +++ b/tensorflow/core/framework/local_rendezvous.cc @@ -0,0 +1,300 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/local_rendezvous.h" + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/gtl/manual_constructor.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Represents a blocked Send() or Recv() call in the rendezvous. +struct LocalRendezvous::Item { + enum Type { kSend = 0, kRecv = 1 }; + + Item(Rendezvous::Args send_args, const Tensor& value, bool is_dead) + : Item(send_args, kSend) { + send_state.value.Init(value); + send_state.is_dead = is_dead; + } + + Item(Rendezvous::Args recv_args, Rendezvous::DoneCallback waiter, + CancellationToken cancellation_token) + : Item(recv_args, kRecv) { + recv_state.waiter.Init(std::move(waiter)); + recv_state.cancellation_token = cancellation_token; + } + + ~Item() { + if (args.device_context) { + args.device_context->Unref(); + } + if (type == kSend) { + send_state.value.Destroy(); + } else { + recv_state.waiter.Destroy(); + } + } + + const Rendezvous::Args args; + const Type type; + + // Link to next item in an ItemQueue. + Item* next = nullptr; + + // The validity of `send_state` or `recv_state` is determined by `type == + // kSend` or `type == kRecv` respectively. + union { + struct { + ManualConstructor value; + bool is_dead; + } send_state; + struct { + ManualConstructor waiter; + CancellationToken cancellation_token; + } recv_state; + }; + + private: + Item(Rendezvous::Args args, Type type) : args(args), type(type) { + if (args.device_context) { + args.device_context->Ref(); + } + } +}; + +void LocalRendezvous::ItemQueue::push_back(Item* item) { + if (TF_PREDICT_TRUE(head == nullptr)) { + // The queue is empty. + head = item; + tail = item; + } else { + DCHECK_EQ(tail->type, item->type); + tail->next = item; + tail = item; + } +} + +LocalRendezvous::~LocalRendezvous() { + if (!table_.empty()) { + StartAbort(errors::Cancelled("LocalRendezvous deleted")); + } +} + +namespace { +uint64 KeyHash(const StringPiece& k) { return Hash64(k.data(), k.size()); } +} // namespace + +Status LocalRendezvous::Send(const Rendezvous::ParsedKey& key, + const Rendezvous::Args& send_args, + const Tensor& val, const bool is_dead) { + uint64 key_hash = KeyHash(key.FullKey()); + DVLOG(2) << "Send " << this << " " << key_hash << " " << key.FullKey(); + + mu_.lock(); + if (!status_.ok()) { + // Rendezvous has been aborted. + Status s = status_; + mu_.unlock(); + return s; + } + + ItemQueue* queue = &table_[key_hash]; + if (queue->head == nullptr || queue->head->type == Item::kSend) { + // There is no waiter for this message. Append the message + // into the queue. The waiter will pick it up when arrives. + // Only send-related fields need to be filled. + // TODO(b/143786186): Investigate moving the allocation of `Item` outside + // the lock. + DVLOG(2) << "Enqueue Send Item (key:" << key.FullKey() << "). "; + queue->push_back(new Item(send_args, val, is_dead)); + mu_.unlock(); + return Status::OK(); + } + + DVLOG(2) << "Consume Recv Item (key:" << key.FullKey() << "). "; + // There is an earliest waiter to consume this message. + Item* item = queue->head; + + // Delete the queue when the last element has been consumed. + if (item->next == nullptr) { + DVLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). "; + table_.erase(key_hash); + } else { + queue->head = item->next; + } + mu_.unlock(); + + // Notify the waiter by invoking its done closure, outside the + // lock. + DCHECK_EQ(item->type, Item::kRecv); + (*item->recv_state.waiter)(Status::OK(), send_args, item->args, val, is_dead); + delete item; + return Status::OK(); +} + +void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key, + const Rendezvous::Args& recv_args, + Rendezvous::DoneCallback done) { + uint64 key_hash = KeyHash(key.FullKey()); + DVLOG(2) << "Recv " << this << " " << key_hash << " " << key.FullKey(); + + mu_.lock(); + if (!status_.ok()) { + // Rendezvous has been aborted. + Status s = status_; + mu_.unlock(); + done(s, Rendezvous::Args(), recv_args, Tensor(), false); + return; + } + + ItemQueue* queue = &table_[key_hash]; + if (queue->head == nullptr || queue->head->type == Item::kRecv) { + // There is no message to pick up. + // Only recv-related fields need to be filled. + CancellationManager* cm = recv_args.cancellation_manager; + CancellationToken token = CancellationManager::kInvalidToken; + bool already_cancelled = false; + if (cm != nullptr) { + token = cm->get_cancellation_token(); + already_cancelled = !cm->RegisterCallback(token, [this, token, key_hash] { + Item* item = nullptr; + { + mutex_lock l(mu_); + ItemQueue* queue = &table_[key_hash]; + // Find an item in the queue with a cancellation token that matches + // `token`, and remove it. + if (queue->head != nullptr && queue->head->type == Item::kRecv) { + for (Item *prev = nullptr, *curr = queue->head; curr != nullptr; + prev = curr, curr = curr->next) { + if (curr->recv_state.cancellation_token == token) { + item = curr; + if (queue->head->next == nullptr) { + // We have a single-element queue, so we can erase it from + // the table. + table_.erase(key_hash); + } else { + // Remove the current item from the queue. + if (curr == queue->head) { + DCHECK_EQ(prev, nullptr); + queue->head = curr->next; + } else { + DCHECK_NE(prev, nullptr); + prev->next = curr->next; + } + if (queue->tail == curr) { + queue->tail = prev; + } + } + break; + } + } + } + } + + if (item != nullptr) { + (*item->recv_state.waiter)( + StatusGroup::MakeDerived( + errors::Cancelled("RecvAsync is cancelled.")), + Rendezvous::Args(), item->args, Tensor(), /*is_dead=*/false); + delete item; + } + }); + } + if (already_cancelled) { + mu_.unlock(); + done(StatusGroup::MakeDerived( + errors::Cancelled("RecvAsync is cancelled.")), + Rendezvous::Args(), recv_args, Tensor(), /*is_dead=*/false); + return; + } + + DVLOG(2) << "Enqueue Recv Item (key:" << key.FullKey() << "). "; + + // TODO(b/143786186): Investigate moving the allocation of `Item` outside + // the lock. + if (cm != nullptr) { + // NOTE(mrry): We must wrap `done` with code that deregisters the + // cancellation callback before calling the `done` callback, because the + // cancellation manager may no longer be live after `done` is called. + queue->push_back(new Item( + recv_args, + [cm, token, done = std::move(done)]( + const Status& s, const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, const Tensor& v, bool dead) { + cm->TryDeregisterCallback(token); + done(s, send_args, recv_args, v, dead); + }, + token)); + } else { + queue->push_back(new Item(recv_args, std::move(done), token)); + } + + mu_.unlock(); + return; + } + + DVLOG(2) << "Consume Send Item (key:" << key.FullKey() << "). "; + // A message has already arrived and is queued in the table under + // this key. Consumes the message and invokes the done closure. + Item* item = queue->head; + + // Delete the queue when the last element has been consumed. + if (item->next == nullptr) { + DVLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). "; + table_.erase(key_hash); + } else { + queue->head = item->next; + } + mu_.unlock(); + + // Invoke done() without holding the table lock. + DCHECK_EQ(item->type, Item::kSend); + done(Status::OK(), item->args, recv_args, *item->send_state.value, + item->send_state.is_dead); + delete item; +} + +void LocalRendezvous::StartAbort(const Status& status) { + CHECK(!status.ok()); + Table table; + { + mutex_lock l(mu_); + status_.Update(status); + table_.swap(table); + } + for (auto& p : table) { + Item* item = p.second.head; + while (item != nullptr) { + if (item->type == Item::kRecv) { + (*item->recv_state.waiter)(status, Rendezvous::Args(), + Rendezvous::Args(), Tensor(), false); + } + Item* to_delete = item; + item = item->next; + delete to_delete; + } + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/local_rendezvous.h b/tensorflow/core/framework/local_rendezvous.h new file mode 100644 index 00000000000..07c52712de7 --- /dev/null +++ b/tensorflow/core/framework/local_rendezvous.h @@ -0,0 +1,75 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_LOCAL_RENDEZVOUS_H_ +#define TENSORFLOW_CORE_FRAMEWORK_LOCAL_RENDEZVOUS_H_ + +#include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Implements the basic logic of matching Send and Recv operations. See +// RendezvousInterface for more details. +// +// NOTE: Most users will use a class that wraps LocalRendezvous, such as +// IntraProcessRendezvous or RemoteRendezvous. This class does not implement +// RendezvousInterface because virtual dispatch to LocalRendezvous methods +// is not expected to be needed. +class LocalRendezvous { + public: + LocalRendezvous() = default; + ~LocalRendezvous(); + + Status Send(const Rendezvous::ParsedKey& key, + const Rendezvous::Args& send_args, const Tensor& val, + const bool is_dead); + void RecvAsync(const Rendezvous::ParsedKey& key, + const Rendezvous::Args& recv_args, + Rendezvous::DoneCallback done); + void StartAbort(const Status& status); + + private: + struct Item; + + // By invariant, the item queue under each key is of the form + // [item.type == kSend]* meaning each item is a sent message. + // or + // [item.type == kRecv]* meaning each item is a waiter. + struct ItemQueue { + void push_back(Item* item); + + Item* head = nullptr; + Item* tail = nullptr; + }; + + typedef gtl::FlatMap Table; + + // TODO(zhifengc): shard table_. + mutex mu_; + Table table_ GUARDED_BY(mu_); + Status status_ GUARDED_BY(mu_); + + TF_DISALLOW_COPY_AND_ASSIGN(LocalRendezvous); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_LOCAL_RENDEZVOUS_H_ diff --git a/tensorflow/core/framework/rendezvous.cc b/tensorflow/core/framework/rendezvous.cc index 18be6238225..764f8995d02 100644 --- a/tensorflow/core/framework/rendezvous.cc +++ b/tensorflow/core/framework/rendezvous.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "tensorflow/core/framework/local_rendezvous.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -148,301 +149,29 @@ Status RendezvousInterface::Recv(const ParsedKey& key, const Args& args, } namespace { -class LocalRendezvousImpl : public Rendezvous { +class LocalRendezvousWrapper : public Rendezvous { public: - explicit LocalRendezvousImpl() {} + LocalRendezvousWrapper() = default; Status Send(const ParsedKey& key, const Args& send_args, const Tensor& val, const bool is_dead) override { - uint64 key_hash = KeyHash(key.FullKey()); - DVLOG(2) << "Send " << this << " " << key_hash << " " << key.FullKey(); - - mu_.lock(); - if (!status_.ok()) { - // Rendezvous has been aborted. - Status s = status_; - mu_.unlock(); - return s; - } - - ItemQueue* queue = &table_[key_hash]; - if (queue->head == nullptr || queue->head->type == Item::kSend) { - // There is no waiter for this message. Append the message - // into the queue. The waiter will pick it up when arrives. - // Only send-related fields need to be filled. - // TODO(b/143786186): Investigate moving the allocation of `Item` outside - // the lock. - DVLOG(2) << "Enqueue Send Item (key:" << key.FullKey() << "). "; - queue->push_back(new Item(send_args, val, is_dead)); - mu_.unlock(); - return Status::OK(); - } - - DVLOG(2) << "Consume Recv Item (key:" << key.FullKey() << "). "; - // There is an earliest waiter to consume this message. - Item* item = queue->head; - - // Delete the queue when the last element has been consumed. - if (item->next == nullptr) { - DVLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). "; - table_.erase(key_hash); - } else { - queue->head = item->next; - } - mu_.unlock(); - - // Notify the waiter by invoking its done closure, outside the - // lock. - DCHECK_EQ(item->type, Item::kRecv); - (*item->recv_state.waiter)(Status::OK(), send_args, item->args, val, - is_dead); - delete item; - return Status::OK(); + return impl_.Send(key, send_args, val, is_dead); } void RecvAsync(const ParsedKey& key, const Args& recv_args, DoneCallback done) override { - uint64 key_hash = KeyHash(key.FullKey()); - DVLOG(2) << "Recv " << this << " " << key_hash << " " << key.FullKey(); - - mu_.lock(); - if (!status_.ok()) { - // Rendezvous has been aborted. - Status s = status_; - mu_.unlock(); - done(s, Args(), recv_args, Tensor(), false); - return; - } - - ItemQueue* queue = &table_[key_hash]; - if (queue->head == nullptr || queue->head->type == Item::kRecv) { - // There is no message to pick up. - // Only recv-related fields need to be filled. - CancellationManager* cm = recv_args.cancellation_manager; - CancellationToken token = CancellationManager::kInvalidToken; - bool already_cancelled = false; - if (cm != nullptr) { - token = cm->get_cancellation_token(); - already_cancelled = !cm->RegisterCallback(token, [this, token, - key_hash] { - Item* item = nullptr; - { - mutex_lock l(mu_); - ItemQueue* queue = &table_[key_hash]; - // Find an item in the queue with a cancellation token that matches - // `token`, and remove it. - if (queue->head != nullptr && queue->head->type == Item::kRecv) { - for (Item *prev = nullptr, *curr = queue->head; curr != nullptr; - prev = curr, curr = curr->next) { - if (curr->recv_state.cancellation_token == token) { - item = curr; - if (queue->head->next == nullptr) { - // We have a single-element queue, so we can erase it from - // the table. - table_.erase(key_hash); - } else { - // Remove the current item from the queue. - if (curr == queue->head) { - DCHECK_EQ(prev, nullptr); - queue->head = curr->next; - } else { - DCHECK_NE(prev, nullptr); - prev->next = curr->next; - } - if (queue->tail == curr) { - queue->tail = prev; - } - } - break; - } - } - } - } - - if (item != nullptr) { - (*item->recv_state.waiter)( - StatusGroup::MakeDerived( - errors::Cancelled("RecvAsync is cancelled.")), - Args(), item->args, Tensor(), /*is_dead=*/false); - delete item; - } - }); - } - if (already_cancelled) { - mu_.unlock(); - done(StatusGroup::MakeDerived( - errors::Cancelled("RecvAsync is cancelled.")), - Args(), recv_args, Tensor(), /*is_dead=*/false); - return; - } - - DVLOG(2) << "Enqueue Recv Item (key:" << key.FullKey() << "). "; - - // TODO(b/143786186): Investigate moving the allocation of `Item` outside - // the lock. - if (cm != nullptr) { - // NOTE(mrry): We must wrap `done` with code that deregisters the - // cancellation callback before calling the `done` callback, because the - // cancellation manager may no longer be live after `done` is called. - queue->push_back(new Item( - recv_args, - [cm, token, done = std::move(done)]( - const Status& s, const Args& send_args, const Args& recv_args, - const Tensor& v, bool dead) { - cm->TryDeregisterCallback(token); - done(s, send_args, recv_args, v, dead); - }, - token)); - } else { - queue->push_back(new Item(recv_args, std::move(done), token)); - } - - mu_.unlock(); - return; - } - - DVLOG(2) << "Consume Send Item (key:" << key.FullKey() << "). "; - // A message has already arrived and is queued in the table under - // this key. Consumes the message and invokes the done closure. - Item* item = queue->head; - - // Delete the queue when the last element has been consumed. - if (item->next == nullptr) { - DVLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). "; - table_.erase(key_hash); - } else { - queue->head = item->next; - } - mu_.unlock(); - - // Invoke done() without holding the table lock. - DCHECK_EQ(item->type, Item::kSend); - done(Status::OK(), item->args, recv_args, *item->send_state.value, - item->send_state.is_dead); - delete item; + impl_.RecvAsync(key, recv_args, std::move(done)); } - void StartAbort(const Status& status) override { - CHECK(!status.ok()); - Table table; - { - mutex_lock l(mu_); - status_.Update(status); - table_.swap(table); - } - for (auto& p : table) { - Item* item = p.second.head; - while (item != nullptr) { - if (item->type == Item::kRecv) { - (*item->recv_state.waiter)(status, Args(), Args(), Tensor(), false); - } - Item* to_delete = item; - item = item->next; - delete to_delete; - } - } - } + void StartAbort(const Status& status) override { impl_.StartAbort(status); } private: - typedef LocalRendezvousImpl ME; + LocalRendezvous impl_; - // Represents a blocked Send() or Recv() call in the rendezvous. - struct Item { - enum Type { kSend = 0, kRecv = 1 }; - - Item(Args send_args, const Tensor& value, bool is_dead) - : Item(send_args, kSend) { - send_state.value.Init(value); - send_state.is_dead = is_dead; - } - - Item(Args recv_args, DoneCallback waiter, - CancellationToken cancellation_token) - : Item(recv_args, kRecv) { - recv_state.waiter.Init(std::move(waiter)); - recv_state.cancellation_token = cancellation_token; - } - - ~Item() { - if (args.device_context) { - args.device_context->Unref(); - } - if (type == kSend) { - send_state.value.Destroy(); - } else { - recv_state.waiter.Destroy(); - } - } - - const Args args; - const Type type; - - // Link to next item in an ItemQueue. - Item* next = nullptr; - - // The validity of `send_state` or `recv_state` is determined by `type == - // kSend` or `type == kRecv` respectively. - union { - struct { - ManualConstructor value; - bool is_dead; - } send_state; - struct { - ManualConstructor waiter; - CancellationToken cancellation_token; - } recv_state; - }; - - private: - Item(Args args, Type type) : args(args), type(type) { - if (args.device_context) { - args.device_context->Ref(); - } - } - }; - - // We key the hash table by KeyHash of the Rendezvous::CreateKey string - static uint64 KeyHash(const StringPiece& k) { - return Hash64(k.data(), k.size()); - } - - // By invariant, the item queue under each key is of the form - // [item.type == kSend]* meaning each item is a sent message. - // or - // [item.type == kRecv]* meaning each item is a waiter. - struct ItemQueue { - void push_back(Item* item) { - if (TF_PREDICT_TRUE(head == nullptr)) { - // The queue is empty. - head = item; - tail = item; - } else { - DCHECK_EQ(tail->type, item->type); - tail->next = item; - tail = item; - } - } - - Item* head = nullptr; - Item* tail = nullptr; - }; - typedef gtl::FlatMap Table; - - // TODO(zhifengc): shard table_. - mutex mu_; - Table table_ GUARDED_BY(mu_); - Status status_ GUARDED_BY(mu_); - - ~LocalRendezvousImpl() override { - if (!table_.empty()) { - StartAbort(errors::Cancelled("LocalRendezvousImpl deleted")); - } - } - - TF_DISALLOW_COPY_AND_ASSIGN(LocalRendezvousImpl); + TF_DISALLOW_COPY_AND_ASSIGN(LocalRendezvousWrapper); }; } // namespace -Rendezvous* NewLocalRendezvous() { return new LocalRendezvousImpl(); } +Rendezvous* NewLocalRendezvous() { return new LocalRendezvousWrapper; } } // end namespace tensorflow From 452aacf6448b4a020970a0dac0c6ba19c8717c97 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 27 Nov 2019 11:43:36 -0800 Subject: [PATCH 052/279] Use explicit inference priorities instead of setting just allow_precision_loss PiperOrigin-RevId: 282808216 Change-Id: If7f79075b8b4107e868d8277d590174a95d758f6 --- tensorflow/lite/tools/evaluation/utils.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/lite/tools/evaluation/utils.cc b/tensorflow/lite/tools/evaluation/utils.cc index ef48aabc399..290e7549908 100644 --- a/tensorflow/lite/tools/evaluation/utils.cc +++ b/tensorflow/lite/tools/evaluation/utils.cc @@ -121,7 +121,7 @@ Interpreter::TfLiteDelegatePtr CreateGPUDelegate( tflite::FlatBufferModel* model) { #if defined(__ANDROID__) TfLiteGpuDelegateOptionsV2 options = TfLiteGpuDelegateOptionsV2Default(); - options.is_precision_loss_allowed = 1; + options.inference_priority1 = TFLITE_GPU_INFERENCE_PRIORITY_MIN_LATENCY; options.inference_preference = TFLITE_GPU_INFERENCE_PREFERENCE_SUSTAINED_SPEED; From 8ef8c28a21690f31de3ed48f0e269526a36dee9d Mon Sep 17 00:00:00 2001 From: Jared Duke Date: Wed, 27 Nov 2019 11:55:44 -0800 Subject: [PATCH 053/279] Add temp/intermediate tensors to interpreter state dump This will make it easier to debug memory-related issues with intermediate tensors. PiperOrigin-RevId: 282810295 Change-Id: Ibc3797884ffeccd23feb16ed22ed4fd760a6355f --- tensorflow/lite/optional_debug_tools.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tensorflow/lite/optional_debug_tools.cc b/tensorflow/lite/optional_debug_tools.cc index 2d33f79c1f0..4e9b7d4e0a4 100644 --- a/tensorflow/lite/optional_debug_tools.cc +++ b/tensorflow/lite/optional_debug_tools.cc @@ -115,6 +115,14 @@ void PrintInterpreterState(Interpreter* interpreter) { PrintTfLiteIntVector(node.inputs); printf(" Outputs:"); PrintTfLiteIntVector(node.outputs); + if (node.intermediates && node.intermediates->size) { + printf(" Intermediates:"); + PrintTfLiteIntVector(node.intermediates); + } + if (node.temporaries && node.temporaries->size) { + printf(" Temporaries:"); + PrintTfLiteIntVector(node.temporaries); + } } } From 9fae7ec3734ce93a38c90baaea97701e3cdda969 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Wed, 27 Nov 2019 11:58:08 -0800 Subject: [PATCH 054/279] Fixed typo in Toy tutorial (second var e -> var f) PiperOrigin-RevId: 282810649 Change-Id: Ic3bd523c99ad411bc79cea59abb8b98a694d7eaf --- third_party/mlir/g3doc/Tutorials/Toy/Ch-1.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/mlir/g3doc/Tutorials/Toy/Ch-1.md b/third_party/mlir/g3doc/Tutorials/Toy/Ch-1.md index 650a135619e..b8beff8d3f5 100644 --- a/third_party/mlir/g3doc/Tutorials/Toy/Ch-1.md +++ b/third_party/mlir/g3doc/Tutorials/Toy/Ch-1.md @@ -145,7 +145,7 @@ Module: var: b @test/ast.toy:21:30 var: c @test/ast.toy:21:33 ] - VarDecl e<> @test/ast.toy:24:3 + VarDecl f<> @test/ast.toy:24:3 Call 'multiply_transpose' [ @test/ast.toy:24:11 Call 'transpose' [ @test/ast.toy:24:30 var: a @test/ast.toy:24:40 From bbd8a3e0e489f7ac138075588cebc3011e43d675 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 27 Nov 2019 12:05:29 -0800 Subject: [PATCH 055/279] Fixed an internal bug where width/height were consistently (and incorrectly) mapped to y/x vs the correct x/y. PiperOrigin-RevId: 282812176 Change-Id: I5df8fa023a8cc873a3cd6af8e012a0460baec8ee --- .../grappler/costs/op_level_cost_estimator.cc | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 8e60d391cd8..751bf952213 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -684,28 +684,28 @@ OpLevelCostEstimator::ConvolutionDimensionsFromInputs( int x_index, y_index, channel_index; const string& data_format = GetDataFormat(op_info); if (data_format == "NCHW") { - x_index = 2; - y_index = 3; channel_index = 1; + y_index = 2; + x_index = 3; } else { // Use NHWC. - x_index = 1; - y_index = 2; + y_index = 1; + x_index = 2; channel_index = 3; } const string& filter_format = GetFilterFormat(op_info); int filter_x_index, filter_y_index, in_channel_index, out_channel_index; if (filter_format == "HWIO") { - filter_x_index = 0; - filter_y_index = 1; + filter_y_index = 0; + filter_x_index = 1; in_channel_index = 2; out_channel_index = 3; } else { // Use OIHW - filter_x_index = 2; - filter_y_index = 3; - in_channel_index = 1; out_channel_index = 0; + in_channel_index = 1; + filter_y_index = 2; + filter_x_index = 3; } int64 batch = image_shape.dim(0).size(); int64 ix = image_shape.dim(x_index).size(); @@ -1311,9 +1311,9 @@ Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation( // TODO(varomodt): should we centralize the Conv2D input/output shapes? OpInfo::TensorProperties output; if (data_format == "NCHW") { - output = DescribeTensor(DT_FLOAT, {dims.batch, dims.oz, dims.ox, dims.oy}); + output = DescribeTensor(DT_FLOAT, {dims.batch, dims.oz, dims.oy, dims.ox}); } else if (data_format == "NHWC") { - output = DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz}); + output = DescribeTensor(DT_FLOAT, {dims.batch, dims.oy, dims.ox, dims.oz}); } // Add the operations the fused op always computes. @@ -1768,12 +1768,12 @@ OpLevelCostEstimator::OpDimensionsFromInputs( int x_index, y_index, channel_index; const string& data_format = GetDataFormat(op_info); if (data_format == "NCHW") { - x_index = 2; - y_index = 3; channel_index = 1; - } else { - x_index = 1; y_index = 2; + x_index = 3; + } else { + y_index = 1; + x_index = 2; channel_index = 3; } int64 batch = image_shape.dim(0).size(); From fe77f7925607fbbeae9bf1a40335aeb7ef71ba83 Mon Sep 17 00:00:00 2001 From: Yuanzhong Xu Date: Wed, 27 Nov 2019 12:29:01 -0800 Subject: [PATCH 056/279] [MLIR:TF/XLA] Side effect analysis. An analysis that infers control dependencies based on side-effects on known and unknown resources. Side-effecting ops on uknown resources are conservatively treated as interferencing all known resource op accesses. It distinguishes accesses based on whether they are read-only, and read-only ops do not interfer with each other. PiperOrigin-RevId: 282815679 Change-Id: I4f430c5b6cbfb02284c150c85b41f0b81458b6e9 --- tensorflow/compiler/mlir/tensorflow/BUILD | 18 + .../analysis/side_effect_analysis.cc | 374 ++++++++++++++++++ .../analysis/side_effect_analysis.h | 125 ++++++ .../tests/side-effect-analysis-test.mlir | 237 +++++++++++ .../transforms/test_side_effect_analysis.cc | 77 ++++ 5 files changed, 831 insertions(+) create mode 100644 tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc create mode 100644 tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/test_side_effect_analysis.cc diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 683c66eced3..7cfa802b1c3 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -217,6 +217,7 @@ cc_library( "transforms/shape_inference.cc", "transforms/shape_inference_pass.cc", "transforms/sink_constant.cc", + "transforms/test_side_effect_analysis.cc", "transforms/tpu_cluster_formation.cc", "transforms/tpu_merge_variables_with_execute.cc", "transforms/tpu_rewrite_pass.cc", @@ -239,6 +240,7 @@ cc_library( ":error_util", ":export_tf_dialect_op", ":mangling_util", + ":side_effect_analysis", ":tensorflow", ":tensorflow_optimize_inc_gen", ":tpu_rewrite_device_util", @@ -981,3 +983,19 @@ cc_library( "@local_config_mlir//:Pass", ], ) + +cc_library( + name = "side_effect_analysis", + srcs = ["analysis/side_effect_analysis.cc"], + hdrs = ["analysis/side_effect_analysis.h"], + deps = [ + ":tensorflow", + "//tensorflow/compiler/tf2xla:resource_operation_table", + "//tensorflow/core:framework", + "@com_google_absl//absl/strings", + "@llvm//:support", + "@local_config_mlir//:IR", + "@local_config_mlir//:Pass", + "@local_config_mlir//:Support", + ], +) diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc new file mode 100644 index 00000000000..8d43c9330d0 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc @@ -0,0 +1,374 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/Block.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/Location.h" // TF:local_config_mlir +#include "mlir/IR/Operation.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "mlir/Support/LLVM.h" // TF:local_config_mlir +#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/tf2xla/resource_operation_table.h" +#include "tensorflow/core/framework/resource_mgr.h" + +namespace mlir { +namespace TF { + +namespace { + +constexpr int64_t kUnknownResourceId = -1; + +// Returns if a VarHandleOp is anonymous, which means it always creates a new +// variable. +bool IsResourceHandleAnonymous(TF::VarHandleOp handle) { + return handle.shared_name() == tensorflow::ResourceHandle::ANONYMOUS_NAME; +} + +// Returns a string unique identifier for a non-anonymous VarHandleOp. +std::string GetVarHandleStringId(TF::VarHandleOp handle) { + auto device = handle.getAttrOfType("device"); + return absl::StrCat(handle.container().str(), "/", handle.shared_name().str(), + "/", device ? device.getValue().str() : std::string("")); +} + +// Finds a unique ID for a VarHandleOp's output. If it is anonymous, always +// creates a new ID; otherwise, tries to reuse the existing ID for the +// referenced variable if it exists, or creates a new one if not. +int64_t GetOrCreateIdForVarHandle(TF::VarHandleOp handle, int64_t* next_id, + llvm::StringMap* name_id_map) { + // Always create a new ID for anonymous handle. + if (IsResourceHandleAnonymous(handle)) return (*next_id)++; + + auto name = GetVarHandleStringId(handle); + auto emplace_res = name_id_map->try_emplace(name, *next_id); + // New ID created, increment next_id. + if (emplace_res.second) ++(*next_id); + return emplace_res.first->second; +} + +} // namespace + +ResourceAliasAnalysis::ResourceAliasAnalysis(Operation* op) { + auto func_op = llvm::dyn_cast(op); + if (!func_op) return; + AnalyzeFunction(func_op); +} + +void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) { + // This function populates resource_value_to_ids_. + // + // TODO(yuanzx): Pass variable aliasing information to functions so we can + // properly resolve aliasing arguments. + // + // Before having that, we assume function arguments do not alias each other. + int64_t next_unique_id = 0; + for (auto arg : func_op.getArguments()) { + if (!mlir::getElementTypeOrSelf(arg->getType()).isa()) + continue; + resource_value_to_ids_[arg].insert(next_unique_id++); + } + llvm::StringMap var_handle_name_id_map; + auto forward_input_to_output = [&](Value* operand, Value* result) { + if (!mlir::getElementTypeOrSelf(result->getType()).isa()) + return; + auto operand_it = resource_value_to_ids_.find(operand); + assert(operand_it != resource_value_to_ids_.end() && + "A resource-type output does not have the corresponding " + "resource-type input."); + resource_value_to_ids_[result].insert(operand_it->getSecond().begin(), + operand_it->getSecond().end()); + }; + // TODO(yuanzx): Consider control-flow ops. + func_op.walk([&](Operation* op) { + if (auto var_handle = llvm::dyn_cast(op)) { + resource_value_to_ids_[var_handle.resource()].insert( + GetOrCreateIdForVarHandle(var_handle, &next_unique_id, + &var_handle_name_id_map)); + } else if (llvm::isa(op) || + llvm::isa(op)) { + for (auto operand_and_result : + llvm::zip(op->getOperands(), op->getResults())) { + forward_input_to_output(std::get<0>(operand_and_result), + std::get<1>(operand_and_result)); + } + } else { + for (auto result : op->getResults()) { + if (!mlir::getElementTypeOrSelf(result->getType()) + .isa()) + continue; + resource_value_to_ids_[result].insert(kUnknownResourceId); + } + } + }); +} + +bool ResourceAliasAnalysis::IsUnknownResource(const Value* resource) const { + auto it = resource_value_to_ids_.find(resource); + assert(it != resource_value_to_ids_.end() && !it->getSecond().empty()); + // The set is sorted so we only need to check the first element since + // kUnknownResourceId < 0. + static_assert(kUnknownResourceId < 0, + "kUnknownResourceId should be negative"); + return *it->getSecond().begin() == kUnknownResourceId; +} + +const llvm::SmallSet& ResourceAliasAnalysis::GetResourceUniqueIds( + const Value* resource) const { + auto it = resource_value_to_ids_.find(resource); + assert(it != resource_value_to_ids_.end() && "Unseen resource was queried"); + return it->getSecond(); +} + +namespace { + +// Returns a set that contains only kUnknownResourceId. +llvm::SmallDenseSet UnknownResourceSet() { + llvm::SmallDenseSet unknown_set; + unknown_set.insert(kUnknownResourceId); + return unknown_set; +} + +// Returns all resources that could be accessed by op, or UnknownResourceSet() +// if we cannot find all of them. +llvm::SmallDenseSet FindAccessedResources( + Operation* op, const ResourceAliasAnalysis& alias_analysis) { + llvm::SmallDenseSet resources; + + for (auto operand : op->getOperands()) { + if (!mlir::getElementTypeOrSelf(operand->getType()).isa()) + continue; + if (alias_analysis.IsUnknownResource(operand)) return UnknownResourceSet(); + const auto& ids = alias_analysis.GetResourceUniqueIds(operand); + resources.insert(ids.begin(), ids.end()); + } + for (auto result : op->getResults()) { + if (!mlir::getElementTypeOrSelf(result->getType()).isa()) + continue; + if (alias_analysis.IsUnknownResource(result)) return UnknownResourceSet(); + const auto& ids = alias_analysis.GetResourceUniqueIds(result); + resources.insert(ids.begin(), ids.end()); + } + return resources; +} + +// Returns an XlaResourceOpInfo (or nullptr if it does not exist) that specifies +// the resource access type of the op. It tells whether the op is read only, +// etc. +// +// TODO(yuanzx): Define this information in a different place. Currently we use +// tensorflow/compiler/tf2xla/resource_operation_table.h. +const tensorflow::XlaResourceOpInfo* GetResourceInfoForOp(Operation* op) { + auto op_name = op->getName().getStringRef().str(); + if (op->getName().getDialect() != + TF::TensorFlowDialect::getDialectNamespace()) { + return nullptr; + } + return tensorflow::GetResourceOpInfoForOp( + op->getName().getStringRef().split('.').second.str()); +} + +// Returns whether `op` accesses resources and it is known to be read-only. +bool OpIsReadOnly(Operation* op) { + auto resource_op_info = GetResourceInfoForOp(op); + return resource_op_info && + resource_op_info->kind() == tensorflow::XlaResourceOpKind::kRead; +} + +// Returns if `op` is a resource declaration. +bool OpIsDeclaration(Operation* op, + const ResourceAliasAnalysis& alias_analysis) { + // TODO(yuanzx): Add other types of resources. + return llvm::isa(op) || + ((llvm::isa(op) || llvm::isa(op)) && + !FindAccessedResources(op, alias_analysis).empty()); +} + +} // namespace + +void SideEffectAnalysis::TrackAccess(int64_t resource_id, Operation* op, + bool read_only) { + if (resource_id == kUnknownResourceId) { + if (read_only) { + // New unknown read is not tracked by any known resource access. + for (auto& entry : per_resource_access_info_) { + entry.getSecond().tracked_last_unknown_read = false; + } + } else { + // Unknown write can clear all other tracked information, since it acts + // like a barrier. + per_resource_access_info_.clear(); + } + } + auto& info = per_resource_access_info_[resource_id]; + if (read_only) { + info.reads_since_last_write.push_back(op); + // Resource read must have carried control dependencies of unknown write. + info.tracked_last_unknown_write = true; + } else { + // Resource write must have carried control dependencies of unknown access. + info.tracked_last_unknown_write = true; + info.tracked_last_unknown_read = true; + info.last_write = op; + info.reads_since_last_write.clear(); + } +} + +void SideEffectAnalysis::AddPredecessorsForAccess(int64_t resource_id, + Operation* op, + bool read_only) { + auto it = per_resource_access_info_.find(resource_id); + if (it == per_resource_access_info_.end()) return; + const auto& access_info = it->getSecond(); + auto& control_predecessors = control_predecessors_[op]; + bool read_tracked = false; + if (!read_only) { + control_predecessors.insert(access_info.reads_since_last_write.begin(), + access_info.reads_since_last_write.end()); + read_tracked = !access_info.reads_since_last_write.empty(); + } + if (access_info.last_write && !read_tracked) { + control_predecessors.insert(access_info.last_write); + } +} + +void SideEffectAnalysis::AnalyzeFunction( + FuncOp func_op, const ResourceAliasAnalysis& alias_analysis) { + // This function populates control_predecessors_ and control_successors_ by + // walking through func_op's body, and tracking resource accesses in + // per_resource_access_info_. + + // Returns whether an access to `resource` can skip control edges from + // prevoius accesses to unknown resources, due to that earlier accesses to + // `resource` already indirectly tracked previous accesses to uknown + // resources. `read_only` specifies the type of access of the current op being + // considered. + auto unknown_access_indirectly_tracked_by_resource = [&](int64_t resource, + bool read_only) { + auto it = per_resource_access_info_.find(resource); + if (it == per_resource_access_info_.end()) return false; + auto unknown_it = per_resource_access_info_.find(kUnknownResourceId); + const bool no_unknown_read = + unknown_it == per_resource_access_info_.end() || + unknown_it->getSecond().reads_since_last_write.empty(); + return read_only + ? it->second.tracked_last_unknown_write + : it->second.tracked_last_unknown_write && + (it->second.tracked_last_unknown_read || no_unknown_read); + }; + + func_op.walk([&](Operation* op) { + // We do not need explicit control edges for declaration ops. + if (OpIsDeclaration(op, alias_analysis)) return; + + auto resource_op_info = GetResourceInfoForOp(op); + if (!resource_op_info && op->hasNoSideEffect()) return; + + llvm::SmallDenseSet resources = + resource_op_info ? FindAccessedResources(op, alias_analysis) + : UnknownResourceSet(); + assert(!resources.empty()); + const bool is_unknown = resources.count(kUnknownResourceId) > 0; + const bool read_only = OpIsReadOnly(op); + bool indirectly_tracked_unknown_access = false; + // First add edges from known resources. + if (is_unknown) { + for (auto& entry : per_resource_access_info_) { + if (entry.getFirst() == kUnknownResourceId) continue; + AddPredecessorsForAccess(entry.getFirst(), op, read_only); + indirectly_tracked_unknown_access |= + unknown_access_indirectly_tracked_by_resource(entry.getFirst(), + read_only); + } + } else { + for (int64_t resource : resources) { + AddPredecessorsForAccess(resource, op, read_only); + indirectly_tracked_unknown_access |= + unknown_access_indirectly_tracked_by_resource(resource, read_only); + // Update access info for known resources. + TrackAccess(resource, op, read_only); + } + } + // If not indirectly tracked, add edges from the unknown resource. + if (!indirectly_tracked_unknown_access) { + AddPredecessorsForAccess(kUnknownResourceId, op, read_only); + } + if (is_unknown) { + // Update access info for unknown resource. + TrackAccess(kUnknownResourceId, op, read_only); + } + }); + + // Populate control_successors_ based on control_predecessors_. + for (auto& entry : control_predecessors_) { + auto op = entry.getFirst(); + for (auto predecessor : entry.getSecond()) { + control_successors_[predecessor].insert(op); + } + } +} + +llvm::SmallVector SideEffectAnalysis::DirectControlPredecessors( + Operation* op, llvm::function_ref filter) const { + llvm::SmallVector result; + auto it = control_predecessors_.find(op); + if (it == control_predecessors_.end()) return result; + result.reserve(it->getSecond().size()); + for (auto predecessor : it->getSecond()) { + if (!filter || filter(predecessor)) result.push_back(predecessor); + } + llvm::sort(result, + [](Operation* a, Operation* b) { return a->isBeforeInBlock(b); }); + return result; +} + +llvm::SmallVector SideEffectAnalysis::DirectControlSuccessors( + Operation* op, llvm::function_ref filter) const { + llvm::SmallVector result; + auto it = control_successors_.find(op); + if (it == control_successors_.end()) return result; + result.reserve(it->getSecond().size()); + for (auto successor : it->getSecond()) { + if (!filter || filter(successor)) result.push_back(successor); + } + llvm::sort(result, + [](Operation* a, Operation* b) { return a->isBeforeInBlock(b); }); + return result; +} + +SideEffectAnalysis::SideEffectAnalysis(Operation* op) { + auto func_op = llvm::dyn_cast(op); + if (!func_op) return; + ResourceAliasAnalysis alias_analysis(op); + AnalyzeFunction(func_op, alias_analysis); +} + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h new file mode 100644 index 00000000000..5eee28a6ae0 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h @@ -0,0 +1,125 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_SIDE_EFFECT_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_SIDE_EFFECT_ANALYSIS_H_ + +#include +#include + +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" +#include "mlir/IR/Function.h" // TF:local_config_mlir +#include "mlir/IR/Operation.h" // TF:local_config_mlir +#include "mlir/IR/Region.h" // TF:local_config_mlir +#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir + +namespace mlir { +namespace TF { + +// An analysis that runs on a function and maps each resource-type value to a +// set of unique int64_t IDs representing the possible resources it could alias. +class ResourceAliasAnalysis { + public: + explicit ResourceAliasAnalysis(Operation* op); + ~ResourceAliasAnalysis() = default; + ResourceAliasAnalysis(ResourceAliasAnalysis&&) = default; + + // Returns if the analysis fails to resolve a resource-type value. + bool IsUnknownResource(const Value* resource) const; + + // Returns the set unique IDs which `resource` could alias. Requires that + // IsUnknownResource(resource) == true. + const llvm::SmallSet& GetResourceUniqueIds( + const Value* resource) const; + + private: + ResourceAliasAnalysis() = default; + + // Runs the analysis on `func_op` and populates resource_value_to_ids_. + void AnalyzeFunction(FuncOp func_op); + + // Maps each resource-type value to a set of unique IDs that it could alias. + llvm::SmallDenseMap, 8> + resource_value_to_ids_; +}; + +// An analysis that runs on a function and infers the control predecessors and +// successors for each op, based on side-effects on known and unknown resources. +// Side-effecting ops on uknown resources are conservatively treated as +// interfering with all known resource op accesses. It distinguishes accesses +// based on whether they are read-only, and read-only ops do not interfer with +// each other. +class SideEffectAnalysis { + public: + explicit SideEffectAnalysis(Operation* op); + SideEffectAnalysis(SideEffectAnalysis&& other) = default; + ~SideEffectAnalysis() = default; + + // Returns a vector of ops that are direct control predecessors of `op`, + // sorted in program order. If `filter` is provided, only predecessors that + // pass the filter (returning true) will be included. + llvm::SmallVector DirectControlPredecessors( + Operation* op, + llvm::function_ref filter = nullptr) const; + + // Returns a vector of ops that are direct control successors of `op`, sorted + // in program order. If `filter` is provided, only successors that pass the + // filter (returning true) will be included. + llvm::SmallVector DirectControlSuccessors( + Operation* op, + llvm::function_ref filter = nullptr) const; + + private: + // Runs the analysis on `func_op` and populates control_predecessors_ and + // control_successors_. + void AnalyzeFunction(FuncOp func_op, + const ResourceAliasAnalysis& alias_analysis); + + // Updates control_predecessors_ for `op` that is being visted, on the given + // `resource_id`. + void AddPredecessorsForAccess(int64_t resource_id, Operation* op, + bool read_only); + + // Adds op's access to per_resource_access_info_. + void TrackAccess(int64_t resource_id, Operation* op, bool read_only); + + // Maps from an op to its control predecessors. + llvm::SmallDenseMap, 8> + control_predecessors_; + // Maps from an op to its control successors. + llvm::SmallDenseMap, 8> + control_successors_; + + // Internal per-resource data structure when we build the dependencies. + struct PerResourceAcessInfo { + // Last op that writes the resource before the current op being analyzed. + Operation* last_write = nullptr; + // Read ops since last_write before the current op being analyzed. + llvm::SmallVector reads_since_last_write; + // Whether previous accesses of this resource already tracked last unknown + // read/write. + bool tracked_last_unknown_read = false; + bool tracked_last_unknown_write = false; + }; + llvm::SmallDenseMap + per_resource_access_info_; +}; + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_SIDE_EFFECT_ANALYSIS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir b/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir new file mode 100644 index 00000000000..c6eb4663e57 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir @@ -0,0 +1,237 @@ +// RUN: tf-opt -split-input-file -tf-test-side-effect-analysis -verify-diagnostics %s | FileCheck %s --dump-input=fail + +// Tests that the pass tracks control dependencies for reads/writes on the same +// resource. + +// CHECK-LABEL: func @non_aliasing_reads_writes +func @non_aliasing_reads_writes( +// expected-remark@above {{ID: 13}} +// expected-remark@above {{Predecessors: {12}}} + %arg0: tensor<*x!tf.resource>>, + %arg1: tensor<*x!tf.resource>>, + %arg2: tensor<32xf32>) -> (tensor<32xf32>) { + %graph = tf_executor.graph { + // expected-remark@above {{ID: 11}} + // expected-remark@above {{Predecessors: {10}}} + // expected-remark@above {{Successors: {12}}} + // CHECK: tf_executor.island + %island:2 = tf_executor.island { + // expected-remark@above {{ID: 9}} + // expected-remark@above {{Predecessors: {8}}} + // expected-remark@above {{Successors: {10}}} + %read0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + // expected-remark@above {{ID: 0}} + // expected-remark@above {{Successors: {1}}} + "tf.AssignVariableOp"(%arg0, %arg2) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + // expected-remark@above {{ID: 1}} + // expected-remark@above {{Predecessors: {0}}} + // expected-remark@above {{Successors: {6}}} + %read1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + // expected-remark@above {{ID: 2}} + // expected-remark@above {{Successors: {5}}} + %var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource>> + // expected-remark@above {{ID: 3}} + %read2 = "tf.ReadVariableOp"(%var_handle) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + // expected-remark@above {{ID: 4}} + // expected-remark@above {{Successors: {8}}} + "tf.AssignVariableOp"(%arg1, %read0) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + // expected-remark@above {{ID: 5}} + // expected-remark@above {{Predecessors: {2}}} + // expected-remark@above {{Successors: {8}}} + "tf.AssignVariableOp"(%arg0, %read2) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + // expected-remark@above {{ID: 6}} + // expected-remark@above {{Predecessors: {1}}} + // expected-remark@above {{Successors: {7}}} + %read3 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + // expected-remark@above {{ID: 7}} + // expected-remark@above {{Predecessors: {6}}} + // expected-remark@above {{Successors: {8}}} + tf_executor.yield %read3 : tensor<32xf32> + // expected-remark@above {{ID: 8}} + // expected-remark@above {{Predecessors: {4,5,7}}} + // expected-remark@above {{Successors: {9}}} + } + tf_executor.fetch %island#0 : tensor<32xf32> + // expected-remark@above {{ID: 10}} + // expected-remark@above {{Predecessors: {9}}} + // expected-remark@above {{Successors: {11}}} + } + return %graph : tensor<32xf32> + // expected-remark@above {{ID: 12}} + // expected-remark@above {{Predecessors: {11}}} + // expected-remark@above {{Successors: {13}}} +} + +// ----- + +// Tests that the pass tracks control dependencies for reads/writes on the two +// resource handles that refer to the same variable. + +// CHECK-LABEL: func @aliasing_reads_writes +func @aliasing_reads_writes(%arg0: tensor<32xf32>) -> () { +// expected-remark@above {{ID: 14}} +// expected-remark@above {{Predecessors: {13}}} + tf_executor.graph { + // expected-remark@above {{ID: 12}} + // expected-remark@above {{Predecessors: {11}}} + // expected-remark@above {{Successors: {13}}} + // CHECK: tf_executor.island + %island = tf_executor.island { + // expected-remark@above {{ID: 10}} + // expected-remark@above {{Predecessors: {9}}} + // expected-remark@above {{Successors: {11}}} + %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource>> + // expected-remark@above {{ID: 0}} + %vh1 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource>> + // expected-remark@above {{ID: 1}} + %vh1_id:2 = "tf.IdentityN"(%vh1, %arg0) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> (tensor<*x!tf.resource>>, tensor<32xf32>) + // expected-remark@above {{ID: 2}} + %read0 = "tf.ReadVariableOp"(%vh0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + // expected-remark@above {{ID: 3}} + // expected-remark@above {{Successors: {4}}} + "tf.AssignVariableOp"(%vh1_id#0, %arg0) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + // expected-remark@above {{ID: 4}} + // expected-remark@above {{Predecessors: {3}}} + // expected-remark@above {{Successors: {5,6}}} + %read1 = "tf.ReadVariableOp"(%vh0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + // expected-remark@above {{ID: 5}} + // expected-remark@above {{Predecessors: {4}}} + // expected-remark@above {{Successors: {7}}} + %read2 = "tf.ReadVariableOp"(%vh1) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + // expected-remark@above {{ID: 6}} + // expected-remark@above {{Predecessors: {4}}} + // expected-remark@above {{Successors: {7}}} + "tf.AssignVariableOp"(%vh0, %read2) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + // expected-remark@above {{ID: 7}} + // expected-remark@above {{Predecessors: {5,6}}} + // expected-remark@above {{Successors: {8}}} + "tf.AssignVariableOp"(%vh1_id#0, %read1) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + // expected-remark@above {{ID: 8}} + // expected-remark@above {{Predecessors: {7}}} + // expected-remark@above {{Successors: {9}}} + tf_executor.yield + // expected-remark@above {{ID: 9}} + // expected-remark@above {{Predecessors: {8}}} + // expected-remark@above {{Successors: {10}}} + } + tf_executor.fetch %island : !tf_executor.control + // expected-remark@above {{ID: 11}} + // expected-remark@above {{Predecessors: {10}}} + // expected-remark@above {{Successors: {12}}} + } + return + // expected-remark@above {{ID: 13}} + // expected-remark@above {{Predecessors: {12}}} + // expected-remark@above {{Successors: {14}}} +} + +// ----- + +// Tests that the pass tracks control dependencies for side-effecting on unknown +// resources. + +// CHECK-LABEL: func @unknown_side_effecting_op +func @unknown_side_effecting_op(%arg0: tensor<32xf32>) -> () { +// expected-remark@above {{ID: 13}} +// expected-remark@above {{Predecessors: {12}}} + tf_executor.graph { + // expected-remark@above {{ID: 11}} + // expected-remark@above {{Predecessors: {10}}} + // expected-remark@above {{Successors: {12}}} + // CHECK: tf_executor.island + %island = tf_executor.island { + // expected-remark@above {{ID: 9}} + // expected-remark@above {{Predecessors: {8}}} + // expected-remark@above {{Successors: {10}}} + %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource>> + // expected-remark@above {{ID: 0}} + %vh1 = "tf.VarHandleOp"() {container = "c", shared_name = "v1"} : () -> tensor<*x!tf.resource>> + // expected-remark@above {{ID: 1}} + %read0 = "tf.ReadVariableOp"(%vh0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + // expected-remark@above {{ID: 2}} + // expected-remark@above {{Successors: {4}}} + "tf.AssignVariableOp"(%vh1, %arg0) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + // expected-remark@above {{ID: 3}} + // expected-remark@above {{Successors: {4}}} + "tf._UnknownSideEffectingOp_"() : () -> () + // expected-remark@above {{ID: 4}} + // expected-remark@above {{Predecessors: {2,3}}} + // expected-remark@above {{Successors: {5,6}}} + %read1 = "tf.ReadVariableOp"(%vh1) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + // expected-remark@above {{ID: 5}} + // expected-remark@above {{Predecessors: {4}}} + // expected-remark@above {{Successors: {7}}} + "tf.AssignVariableOp"(%vh0, %read1) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + // expected-remark@above {{ID: 6}} + // expected-remark@above {{Predecessors: {4}}} + // expected-remark@above {{Successors: {8}}} + "tf.AssignVariableOp"(%vh1, %read0) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + // expected-remark@above {{ID: 7}} + // expected-remark@above {{Predecessors: {5}}} + // expected-remark@above {{Successors: {8}}} + tf_executor.yield + // expected-remark@above {{ID: 8}} + // expected-remark@above {{Predecessors: {6,7}}} + // expected-remark@above {{Successors: {9}}} + } + tf_executor.fetch %island : !tf_executor.control + // expected-remark@above {{ID: 10}} + // expected-remark@above {{Predecessors: {9}}} + // expected-remark@above {{Successors: {11}}} + } + return + // expected-remark@above {{ID: 12}} + // expected-remark@above {{Predecessors: {11}}} + // expected-remark@above {{Successors: {13}}} +} + +// ----- + +// Tests that the pass tracks control dependencies for read-only ops on unknown +// resources. + +// CHECK-LABEL: func @read_only_unknown_resource +func @read_only_unknown_resource(%arg0: tensor<32xf32>) -> () { +// expected-remark@above {{ID: 10}} +// expected-remark@above {{Predecessors: {9}}} + tf_executor.graph { + // expected-remark@above {{ID: 8}} + // expected-remark@above {{Predecessors: {7}}} + // expected-remark@above {{Successors: {9}}} + // CHECK: tf_executor.island + %island = tf_executor.island { + // expected-remark@above {{ID: 6}} + // expected-remark@above {{Predecessors: {5}}} + // expected-remark@above {{Successors: {7}}} + %vh0 = "tf._UnknownSideEffectingOp_"() : () -> tensor<*x!tf.resource>> + // expected-remark@above {{ID: 0}} + // expected-remark@above {{Successors: {2,3}}} + %vh1 = "tf.VarHandleOp"() {container = "c", shared_name = "v1"} : () -> tensor<*x!tf.resource>> + // expected-remark@above {{ID: 1}} + %read0 = "tf.ReadVariableOp"(%vh0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + // expected-remark@above {{ID: 2}} + // expected-remark@above {{Predecessors: {0}}} + // expected-remark@above {{Successors: {4}}} + %read1 = "tf.ReadVariableOp"(%vh1) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + // expected-remark@above {{ID: 3}} + // expected-remark@above {{Predecessors: {0}}} + // expected-remark@above {{Successors: {4}}} + "tf.AssignVariableOp"(%vh1, %read0) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + // expected-remark@above {{ID: 4}} + // expected-remark@above {{Predecessors: {2,3}}} + // expected-remark@above {{Successors: {5}}} + tf_executor.yield + // expected-remark@above {{ID: 5}} + // expected-remark@above {{Predecessors: {4}}} + // expected-remark@above {{Successors: {6}}} + } + tf_executor.fetch %island : !tf_executor.control + // expected-remark@above {{ID: 7}} + // expected-remark@above {{Predecessors: {6}}} + // expected-remark@above {{Successors: {8}}} + } + return + // expected-remark@above {{ID: 9}} + // expected-remark@above {{Predecessors: {8}}} + // expected-remark@above {{Successors: {10}}} +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/test_side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/transforms/test_side_effect_analysis.cc new file mode 100644 index 00000000000..f0b7964389d --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/test_side_effect_analysis.cc @@ -0,0 +1,77 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" +#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Pass/PassManager.h" // TF:local_config_mlir +#include "mlir/Support/LLVM.h" // TF:local_config_mlir +#include "mlir/Transforms/Passes.h" // TF:local_config_mlir +#include "mlir/Transforms/RegionUtils.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" + +namespace mlir { +namespace tf_executor { + +namespace { + +// A pass that adds "Predecessors" and "Successors" remarks for each op based on +// SideEffectAnalysis result. For testing purpose only. +struct TestSideEffectAnalysis + : public mlir::FunctionPass { + void runOnFunction() override { + int64_t next_id = 0; + llvm::SmallDenseMap ids; + getFunction().walk([&](Operation* op) { + ids[op] = next_id++; + op->emitRemark("ID: ") << ids[op]; + }); + auto join_ids = [&](const llvm::ArrayRef ops) { + llvm::SmallVector id_vec; + id_vec.reserve(ops.size()); + for (auto op : ops) id_vec.push_back(std::to_string(ids[op])); + return llvm::join(id_vec, ","); + }; + auto& analysis = getAnalysis(); + getFunction().walk([&](Operation* op) { + if (!analysis.DirectControlPredecessors(op).empty()) { + op->emitRemark("Predecessors: ") + << "{" << join_ids(analysis.DirectControlPredecessors(op)) << "}"; + } + if (!analysis.DirectControlSuccessors(op).empty()) { + op->emitRemark("Successors: ") + << "{" << join_ids(analysis.DirectControlSuccessors(op)) << "}"; + } + }); + } +}; + +static mlir::PassRegistration pass( + "tf-test-side-effect-analysis", + "Add remarks based on side-effect analysis result, for testing purpose."); + +} // anonymous namespace + +} // namespace tf_executor +} // namespace mlir From 5d55514c43db082ccef44354470473258692c259 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 27 Nov 2019 12:42:25 -0800 Subject: [PATCH 057/279] Add CPU benchmarks for Conv2D gradients PiperOrigin-RevId: 282817699 Change-Id: Ib658159605dec1ed653b53db9422742971868101 --- .../conv_grad_filter_ops_benchmark_test.cc | 110 ++++++++++-------- .../conv_grad_input_ops_benchmark_test.cc | 98 +++++++++------- 2 files changed, 112 insertions(+), 96 deletions(-) diff --git a/tensorflow/core/kernels/conv_grad_filter_ops_benchmark_test.cc b/tensorflow/core/kernels/conv_grad_filter_ops_benchmark_test.cc index bb6eb846408..9b168045047 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops_benchmark_test.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops_benchmark_test.cc @@ -38,20 +38,21 @@ static Tensor MakeRandomTensor(const TensorShape& shape) { template static Graph* Conv2DBackpropFilter(int batch, int height, int width, - int in_depth, int filter_w, int filter_h, - int out_depth, TensorFormat data_format) { + int in_depth, int filter_h, int filter_w, + int out_depth, int stride_h, int stride_w, + TensorFormat data_format) { auto* graph = new Graph(OpRegistry::Global()); Tensor input_t = data_format == FORMAT_NHWC ? MakeRandomTensor({batch, height, width, in_depth}) : MakeRandomTensor({batch, in_depth, height, width}); Tensor filter_t = - MakeRandomTensor({filter_w, filter_h, in_depth, out_depth}); + MakeRandomTensor({filter_h, filter_w, in_depth, out_depth}); // Compute dimensions for the `out_backprop` tensor. Conv2DParameters params; params.dilations = {1, 1, 1, 1}; - params.strides = {1, 1, 1, 1}; + params.strides = {1, stride_h, stride_w, 1}; params.padding = Padding::SAME; params.data_format = data_format; @@ -83,7 +84,7 @@ static Graph* Conv2DBackpropFilter(int batch, int height, int width, .Input(filter_dims) .Input(backprop) .Attr("T", DataTypeToEnum::value) - .Attr("strides", {1, 1, 1, 1}) + .Attr("strides", {1, stride_h, stride_w, 1}) .Attr("padding", "SAME") .Attr("data_format", ToString(data_format)) .Finalize(graph, &conv2d)); @@ -91,12 +92,6 @@ static Graph* Conv2DBackpropFilter(int batch, int height, int width, return graph; } -// -------------------------------------------------------------------------- // -// The following benchmarks are used to compare different data format -// performance for different data types. They make sense only when CUDA enabled, -// because on CPU we only support data in NHWC. -// -------------------------------------------------------------------------- // - // Macro arguments names: --------------------------------------------------- // // T: data type // FORMAT: data format (NHWC or NCHW) @@ -107,57 +102,70 @@ static Graph* Conv2DBackpropFilter(int batch, int height, int width, // FC: filter count // FH: filter height // FW: filter width +// SH: stride height +// SW: stride width -#define BM_NAME(name, type, T, FORMAT, N, H, W, C, FW, FH, FC) \ - name##_##T##_##FORMAT##_##type##_##N##_##H##_##W##_##C##_##FW##_##FH##_##FC +#define BM_CONCAT(a, b) a##_##b -#define BM_Conv2DBwdFilterFmt(T, FORMAT, N, H, W, C, FW, FH, FC, type) \ - static void BM_NAME(BM_Conv2DBackpropFilter, type, T, FORMAT, N, H, W, C, \ - FW, FH, FC)(int iters) { \ - testing::ItemsProcessed(static_cast(iters) * (N) * (H) * (W) * \ - (C)); \ - test::Benchmark(#type, Conv2DBackpropFilter(N, H, W, C, FW, FH, FC, \ - FORMAT_##FORMAT)) \ - .Run(iters); \ - } \ - BENCHMARK(BM_NAME(BM_Conv2DBackpropFilter, type, T, FORMAT, N, H, W, C, FW, \ - FH, FC)); +#define BM_NAME(name, type, T, FORMAT, N, H, W, C, FH, FW, FC, SH, SW) \ + BM_CONCAT(name##_##T##_##FORMAT##_##type##_in##N##x##H##x##W##x##C, \ + f##FH##x##FW##x##FC##_##s##SH##x##SW) + +#define BM_Conv2DBwdFilterFmt(T, FORMAT, N, H, W, C, FH, FW, FC, SH, SW, type) \ + static void BM_NAME(BM_Conv2DBackpropFilter, type, T, FORMAT, N, H, W, C, \ + FH, FW, FC, SH, SW)(int iters) { \ + testing::ItemsProcessed(static_cast(iters) * (N) * (H) * (W) * \ + (C)); \ + test::Benchmark(#type, Conv2DBackpropFilter(N, H, W, C, FH, FW, FC, SH, \ + SW, FORMAT_##FORMAT)) \ + .Run(iters); \ + } \ + BENCHMARK(BM_NAME(BM_Conv2DBackpropFilter, type, T, FORMAT, N, H, W, C, FH, \ + FW, FC, SH, SW)); + +// ResNet50-ish convolutions. +#define BENCHMARK_DTYPE(FORMAT, BATCH, T, D) \ + BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 56, 56, 64, 1, 1, 64, 1, 1, D); \ + BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 56, 56, 64, 1, 1, 256, 1, 1, D); \ + BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 56, 56, 256, 1, 1, 64, 1, 1, D); \ + BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 56, 56, 64, 3, 3, 64, 1, 1, D); \ + \ + BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 28, 28, 128, 1, 1, 128, 1, 1, D); \ + BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 28, 28, 128, 1, 1, 512, 1, 1, D); \ + BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 28, 28, 512, 1, 1, 128, 1, 1, D); \ + BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 28, 28, 512, 3, 3, 128, 1, 1, D); \ + \ + BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 14, 14, 256, 1, 1, 256, 1, 1, D); \ + BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 14, 14, 256, 1, 1, 1024, 1, 1, D); \ + BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 14, 14, 1024, 1, 1, 256, 1, 1, D); \ + BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 14, 14, 256, 3, 3, 256, 1, 1, D); -#if GOOGLE_CUDA using fp32 = float; using fp16 = Eigen::half; -// ResNet50-ish convolutions. -#define BENCHMARK_DTYPE(FORMAT, BATCH, T) \ - BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 56, 56, 64, 1, 1, 64, gpu); \ - BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 56, 56, 64, 1, 1, 256, gpu); \ - BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 56, 56, 256, 1, 1, 64, gpu); \ - BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 56, 56, 64, 3, 3, 64, gpu); \ - \ - BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 28, 28, 128, 1, 1, 128, gpu); \ - BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 28, 28, 128, 1, 1, 512, gpu); \ - BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 28, 28, 512, 1, 1, 128, gpu); \ - BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 28, 28, 512, 3, 3, 128, gpu); \ - \ - BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 14, 14, 256, 1, 1, 256, gpu); \ - BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 14, 14, 256, 1, 1, 1024, gpu); \ - BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 14, 14, 1024, 1, 1, 256, gpu); \ - BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 14, 14, 256, 3, 3, 256, gpu); +BENCHMARK_DTYPE(NHWC, 8, fp32, cpu); +BENCHMARK_DTYPE(NHWC, 16, fp32, cpu); +BENCHMARK_DTYPE(NHWC, 32, fp32, cpu); -BENCHMARK_DTYPE(NHWC, 32, fp32); -BENCHMARK_DTYPE(NCHW, 32, fp32); +#if GOOGLE_CUDA +// -------------------------------------------------------------------------- // +// The following benchmarks are used to compare different data format +// performance for different data types. They make sense only when CUDA enabled, +// because on CPU we only support data in NHWC. +// -------------------------------------------------------------------------- // -BENCHMARK_DTYPE(NHWC, 32, fp16); -BENCHMARK_DTYPE(NCHW, 32, fp16); +BENCHMARK_DTYPE(NHWC, 32, fp32, gpu); +BENCHMARK_DTYPE(NCHW, 32, fp32, gpu); -BENCHMARK_DTYPE(NHWC, 64, fp32); -BENCHMARK_DTYPE(NCHW, 64, fp32); +BENCHMARK_DTYPE(NHWC, 32, fp16, gpu); +BENCHMARK_DTYPE(NCHW, 32, fp16, gpu); -BENCHMARK_DTYPE(NHWC, 64, fp16); -BENCHMARK_DTYPE(NCHW, 64, fp16); +BENCHMARK_DTYPE(NHWC, 64, fp32, gpu); +BENCHMARK_DTYPE(NCHW, 64, fp32, gpu); + +BENCHMARK_DTYPE(NHWC, 64, fp16, gpu); +BENCHMARK_DTYPE(NCHW, 64, fp16, gpu); #endif // GOOGLE_CUDA -BM_Conv2DBwdFilterFmt(float, NHWC, 8, 32, 32, 128, 1, 1, 128, cpu); - } // namespace tensorflow diff --git a/tensorflow/core/kernels/conv_grad_input_ops_benchmark_test.cc b/tensorflow/core/kernels/conv_grad_input_ops_benchmark_test.cc index 938ef976ed8..70a08b2496c 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops_benchmark_test.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops_benchmark_test.cc @@ -38,8 +38,9 @@ static Tensor MakeRandomTensor(const TensorShape& shape) { template static Graph* Conv2DBackpropInput(int batch, int height, int width, - int in_depth, int filter_w, int filter_h, - int out_depth, TensorFormat data_format) { + int in_depth, int filter_h, int filter_w, + int out_depth, int stride_h, int stride_w, + TensorFormat data_format) { auto* graph = new Graph(OpRegistry::Global()); Tensor input_t = data_format == FORMAT_NHWC @@ -51,7 +52,7 @@ static Graph* Conv2DBackpropInput(int batch, int height, int width, // Compute dimensions for the `out_backprop` tensor. Conv2DParameters params; params.dilations = {1, 1, 1, 1}; - params.strides = {1, 1, 1, 1}; + params.strides = {1, stride_h, stride_w, 1}; params.padding = Padding::SAME; params.data_format = data_format; @@ -83,7 +84,7 @@ static Graph* Conv2DBackpropInput(int batch, int height, int width, .Input(filter) .Input(backprop) .Attr("T", DataTypeToEnum::value) - .Attr("strides", {1, 1, 1, 1}) + .Attr("strides", {1, stride_h, stride_w, 1}) .Attr("padding", "SAME") .Attr("data_format", ToString(data_format)) .Finalize(graph, &conv2d)); @@ -91,12 +92,6 @@ static Graph* Conv2DBackpropInput(int batch, int height, int width, return graph; } -// -------------------------------------------------------------------------- // -// The following benchmarks are used to compare different data format -// performance for different data types. They make sense only when CUDA enabled, -// because on CPU we only support data in NHWC. -// -------------------------------------------------------------------------- // - // Macro arguments names: --------------------------------------------------- // // T: data type // FORMAT: data format (NHWC or NCHW) @@ -107,57 +102,70 @@ static Graph* Conv2DBackpropInput(int batch, int height, int width, // FC: filter count // FH: filter height // FW: filter width +// SH: stride height +// SW: stride width -#define BM_NAME(name, type, T, FORMAT, N, H, W, C, FW, FH, FC) \ - name##_##T##_##FORMAT##_##type##_##N##_##H##_##W##_##C##_##FW##_##FH##_##FC +#define BM_CONCAT(a, b) a##_##b -#define BM_Conv2DBwdInputFmt(T, FORMAT, N, H, W, C, FW, FH, FC, type) \ - static void BM_NAME(BM_Conv2DBackpropInput, type, T, FORMAT, N, H, W, C, FW, \ - FH, FC)(int iters) { \ +#define BM_NAME(name, type, T, FORMAT, N, H, W, C, FH, FW, FC, SH, SW) \ + BM_CONCAT(name##_##T##_##FORMAT##_##type##_in##N##x##H##x##W##x##C, \ + f##FH##x##FW##x##FC##_##s##SH##x##SW) + +#define BM_Conv2DBwdInputFmt(T, FORMAT, N, H, W, C, FW, FH, FC, SH, SW, type) \ + static void BM_NAME(BM_Conv2DBackpropInput, type, T, FORMAT, N, H, W, C, FH, \ + FW, FC, SH, SW)(int iters) { \ testing::ItemsProcessed(static_cast(iters) * (N) * (H) * (W) * \ (C)); \ - test::Benchmark(#type, Conv2DBackpropInput(N, H, W, C, FW, FH, FC, \ - FORMAT_##FORMAT)) \ + test::Benchmark(#type, Conv2DBackpropInput(N, H, W, C, FH, FW, FC, SH, \ + SW, FORMAT_##FORMAT)) \ .Run(iters); \ } \ - BENCHMARK(BM_NAME(BM_Conv2DBackpropInput, type, T, FORMAT, N, H, W, C, FW, \ - FH, FC)); + BENCHMARK(BM_NAME(BM_Conv2DBackpropInput, type, T, FORMAT, N, H, W, C, FH, \ + FW, FC, SH, SW)); -#if GOOGLE_CUDA using fp32 = float; using fp16 = Eigen::half; // ResNet50-ish convolutions. -#define BENCHMARK_DTYPE(FORMAT, BATCH, T) \ - BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 56, 56, 64, 1, 1, 64, gpu); \ - BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 56, 56, 64, 1, 1, 256, gpu); \ - BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 56, 56, 256, 1, 1, 64, gpu); \ - BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 56, 56, 64, 3, 3, 64, gpu); \ - \ - BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 28, 28, 128, 1, 1, 128, gpu); \ - BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 28, 28, 128, 1, 1, 512, gpu); \ - BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 28, 28, 512, 1, 1, 128, gpu); \ - BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 28, 28, 512, 3, 3, 128, gpu); \ - \ - BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 14, 14, 256, 1, 1, 256, gpu); \ - BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 14, 14, 256, 1, 1, 1024, gpu); \ - BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 14, 14, 1024, 1, 1, 256, gpu); \ - BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 14, 14, 256, 3, 3, 256, gpu); +#define BENCHMARK_DTYPE(FORMAT, BATCH, T, D) \ + BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 56, 56, 64, 1, 1, 64, 1, 1, D); \ + BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 56, 56, 64, 1, 1, 256, 1, 1, D); \ + BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 56, 56, 256, 1, 1, 64, 1, 1, D); \ + BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 56, 56, 64, 3, 3, 64, 1, 1, D); \ + \ + BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 28, 28, 128, 1, 1, 128, 1, 1, D); \ + BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 28, 28, 128, 1, 1, 512, 1, 1, D); \ + BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 28, 28, 512, 1, 1, 128, 1, 1, D); \ + BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 28, 28, 512, 3, 3, 128, 1, 1, D); \ + \ + BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 14, 14, 256, 1, 1, 256, 1, 1, D); \ + BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 14, 14, 256, 1, 1, 1024, 1, 1, D); \ + BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 14, 14, 1024, 1, 1, 256, 1, 1, D); \ + BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 14, 14, 256, 3, 3, 256, 1, 1, D); -BENCHMARK_DTYPE(NHWC, 32, fp32); -BENCHMARK_DTYPE(NCHW, 32, fp32); +BENCHMARK_DTYPE(NHWC, 8, fp32, cpu); +BENCHMARK_DTYPE(NHWC, 16, fp32, cpu); +BENCHMARK_DTYPE(NHWC, 32, fp32, cpu); -BENCHMARK_DTYPE(NHWC, 32, fp16); -BENCHMARK_DTYPE(NCHW, 32, fp16); +#if GOOGLE_CUDA +// -------------------------------------------------------------------------- // +// The following benchmarks are used to compare different data format +// performance for different data types. They make sense only when CUDA enabled, +// because on CPU we only support data in NHWC. +// -------------------------------------------------------------------------- // -BENCHMARK_DTYPE(NHWC, 64, fp32); -BENCHMARK_DTYPE(NCHW, 64, fp32); +BENCHMARK_DTYPE(NHWC, 32, fp32, gpu); +BENCHMARK_DTYPE(NCHW, 32, fp32, gpu); -BENCHMARK_DTYPE(NHWC, 64, fp16); -BENCHMARK_DTYPE(NCHW, 64, fp16); +BENCHMARK_DTYPE(NHWC, 32, fp16, gpu); +BENCHMARK_DTYPE(NCHW, 32, fp16, gpu); + +BENCHMARK_DTYPE(NHWC, 64, fp32, gpu); +BENCHMARK_DTYPE(NCHW, 64, fp32, gpu); + +BENCHMARK_DTYPE(NHWC, 64, fp16, gpu); +BENCHMARK_DTYPE(NCHW, 64, fp16, gpu); #endif // GOOGLE_CUDA -BM_Conv2DBwdInputFmt(float, NHWC, 8, 32, 32, 128, 1, 1, 128, cpu); - } // namespace tensorflow From 904f7fea4ca06ed1f86377130ce9af217315f2f6 Mon Sep 17 00:00:00 2001 From: Haoliang Zhang Date: Wed, 27 Nov 2019 13:34:01 -0800 Subject: [PATCH 058/279] Update op version map for tf 2.1 RC0. PiperOrigin-RevId: 282826021 Change-Id: If60097ccff777dae027600364561c3865af176fd --- tensorflow/lite/toco/tflite/op_version.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/lite/toco/tflite/op_version.cc b/tensorflow/lite/toco/tflite/op_version.cc index 39258339e0e..a7a829e77e3 100644 --- a/tensorflow/lite/toco/tflite/op_version.cc +++ b/tensorflow/lite/toco/tflite/op_version.cc @@ -74,7 +74,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) { {{OperatorType::kCast, 1}, "1.5.0"}, {{OperatorType::kConcatenation, 1}, "1.5.0"}, {{OperatorType::kConcatenation, 2}, "1.14.0"}, - {{OperatorType::kDepthToSpace, 1}, kPendingReleaseOpVersion}, + {{OperatorType::kDepthToSpace, 1}, "2.1.0"}, {{OperatorType::kFakeQuant, 1}, "1.5.0"}, {{OperatorType::kFakeQuant, 2}, "1.10.0"}, {{OperatorType::kFullyConnected, 1}, "1.5.0"}, @@ -82,7 +82,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) { {{OperatorType::kFullyConnected, 3}, "1.14.0"}, {{OperatorType::kFullyConnected, 4}, "1.14.0"}, {{OperatorType::kFullyConnected, 5}, "2.0.0"}, - {{OperatorType::kFullyConnected, 6}, kPendingReleaseOpVersion}, + {{OperatorType::kFullyConnected, 6}, "2.1.0"}, {{OperatorType::kGather, 1}, "1.6.0"}, {{OperatorType::kGather, 2}, "1.14.0"}, {{OperatorType::kGather, 3}, "1.15.0"}, @@ -145,7 +145,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) { {{OperatorType::kSplitV, 1}, "1.13.1"}, {{OperatorType::kStridedSlice, 1}, "1.6.0"}, {{OperatorType::kStridedSlice, 2}, "1.14.0"}, - {{OperatorType::kStridedSlice, 3}, kPendingReleaseOpVersion}, + {{OperatorType::kStridedSlice, 3}, "2.1.0"}, {{OperatorType::kTopK_V2, 1}, "1.7.0"}, {{OperatorType::kTopK_V2, 2}, "1.14.0"}, {{OperatorType::kArgMax, 1}, "1.9.0"}, @@ -205,7 +205,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) { {{OperatorType::kElu, 1}, "1.14.0"}, {{OperatorType::kRound, 1}, "1.14.0"}, {{OperatorType::kRelu, 1}, "1.5.0"}, - {{OperatorType::kRelu, 2}, kPendingReleaseOpVersion}, + {{OperatorType::kRelu, 2}, "2.1.0"}, {{OperatorType::kRelu1, 1}, "1.5.0"}, {{OperatorType::kPRelu, 1}, "1.8.0"}, {{OperatorType::kExp, 1}, "1.7.0"}, From 0699728bf9973c3e9429d1eaa83ec0b9c1c831fc Mon Sep 17 00:00:00 2001 From: Jared Duke Date: Wed, 27 Nov 2019 13:42:27 -0800 Subject: [PATCH 059/279] Migrate the TFLite C API out of lite/experimental Follow-up work will involve introducing a package target that bundles the native shared library with all necessary headers. RELNOTES: Migrated the TFLite C inference API out of experimental into lite/c. PiperOrigin-RevId: 282827414 Change-Id: Ibbef3dee899576b770c9410d212a0eb4087fe710 --- tensorflow/lite/build_def.bzl | 2 + tensorflow/lite/c/BUILD | 123 +++- tensorflow/lite/c/README.md | 48 ++ tensorflow/lite/{experimental => }/c/c_api.cc | 5 +- tensorflow/lite/{experimental => }/c/c_api.h | 65 +- .../c/c_api_experimental.cc | 4 +- .../{experimental => }/c/c_api_experimental.h | 6 +- .../c/c_api_experimental_test.cc | 4 +- .../{experimental => }/c/c_api_internal.h | 8 +- .../lite/{experimental => }/c/c_api_test.cc | 2 +- .../{experimental => }/c/exported_symbols.lds | 0 .../{experimental => }/c/version_script.lds | 0 tensorflow/lite/experimental/c/BUILD | 120 ---- tensorflow/lite/experimental/c/README.md | 1 + tensorflow/lite/experimental/c/c_api_types.h | 673 ------------------ .../unity/TensorFlowLitePlugin/README.md | 11 +- tensorflow/lite/experimental/ios/BUILD.apple | 15 +- tensorflow/lite/experimental/objc/BUILD.apple | 4 +- .../Configs/TensorFlowLite.tulsigen | 2 +- .../objc/TensorFlowLiteObjC-nightly.podspec | 2 +- .../objc/TensorFlowLiteObjC.podspec | 2 +- .../objc/sources/TFLInterpreter.mm | 2 +- .../Configs/TensorFlowLite.tulsigen | 2 +- 23 files changed, 252 insertions(+), 849 deletions(-) create mode 100644 tensorflow/lite/c/README.md rename tensorflow/lite/{experimental => }/c/c_api.cc (97%) rename tensorflow/lite/{experimental => }/c/c_api.h (83%) rename tensorflow/lite/{experimental => }/c/c_api_experimental.cc (93%) rename tensorflow/lite/{experimental => }/c/c_api_experimental.h (92%) rename tensorflow/lite/{experimental => }/c/c_api_experimental_test.cc (94%) rename tensorflow/lite/{experimental => }/c/c_api_internal.h (91%) rename tensorflow/lite/{experimental => }/c/c_api_test.cc (99%) rename tensorflow/lite/{experimental => }/c/exported_symbols.lds (100%) rename tensorflow/lite/{experimental => }/c/version_script.lds (100%) delete mode 100644 tensorflow/lite/experimental/c/BUILD create mode 100644 tensorflow/lite/experimental/c/README.md delete mode 100644 tensorflow/lite/experimental/c/c_api_types.h diff --git a/tensorflow/lite/build_def.bzl b/tensorflow/lite/build_def.bzl index 2fb3ea5e714..f37ab23a67a 100644 --- a/tensorflow/lite/build_def.bzl +++ b/tensorflow/lite/build_def.bzl @@ -153,6 +153,7 @@ def tflite_cc_shared_object( linkstatic = 1, deps = [], visibility = None, + per_os_targets = False, tags = None): """Builds a shared object for TFLite.""" tf_cc_shared_object( @@ -164,6 +165,7 @@ def tflite_cc_shared_object( deps = deps, visibility = visibility, tags = tags, + per_os_targets = per_os_targets, ) def tf_to_tflite(name, src, options, out): diff --git a/tensorflow/lite/c/BUILD b/tensorflow/lite/c/BUILD index 37b996c565c..629320370cb 100644 --- a/tensorflow/lite/c/BUILD +++ b/tensorflow/lite/c/BUILD @@ -1,8 +1,128 @@ +load( + "//tensorflow/lite:build_def.bzl", + "tflite_cc_shared_object", + "tflite_copts", +) + package( - default_visibility = ["//visibility:public"], + default_visibility = [":experimental"], licenses = ["notice"], # Apache 2.0 ) +package_group( + name = "experimental", + packages = [ + "//tensorflow/lite/...", + "//third_party/dart/tflite_native/...", # whitelisted + ], +) + +# Generates a platform-specific shared library containing the TensorFlow Lite C +# API implementation as define in `c_api.h`. The exact output library name +# is platform dependent: +# - Linux/Android: `libtensorflowlite_c.so` +# - Mac: `libtensorflowlite_c.dylib` +# - Windows: `tensorflowlite_c.dll` +tflite_cc_shared_object( + name = "tensorflowlite_c", + linkopts = select({ + "//tensorflow:macos": [ + "-Wl,-exported_symbols_list,$(location //tensorflow/lite/c:exported_symbols.lds)", + ], + "//tensorflow:windows": [], + "//conditions:default": [ + "-z defs", + "-Wl,--version-script,$(location //tensorflow/lite/c:version_script.lds)", + ], + }), + per_os_targets = True, + deps = [ + ":c_api", + ":c_api_experimental", + ":exported_symbols.lds", + ":version_script.lds", + ], +) + +cc_library( + name = "c_api_internal", + srcs = [ + "c_api.h", + "common.h", + ], + hdrs = ["c_api_internal.h"], + copts = tflite_copts(), + visibility = ["//visibility:private"], + deps = [ + ":common", + "//tensorflow/lite:framework", + ], +) + +cc_library( + name = "c_api", + srcs = ["c_api.cc"], + hdrs = [ + "c_api.h", + "common.h", + ], + copts = tflite_copts(), + visibility = [ + ":experimental", + ], + deps = [ + ":c_api_internal", + ":common", + "//tensorflow/lite:framework", + "//tensorflow/lite:version", + "//tensorflow/lite/kernels:builtin_ops", + ], + alwayslink = 1, +) + +cc_library( + name = "c_api_experimental", + srcs = ["c_api_experimental.cc"], + hdrs = ["c_api_experimental.h"], + copts = tflite_copts(), + deps = [ + ":c_api", + ":c_api_internal", + "//tensorflow/lite:kernel_api", + ], + alwayslink = 1, +) + +cc_test( + name = "c_api_test", + size = "small", + srcs = ["c_api_test.cc"], + data = [ + "//tensorflow/lite:testdata/add.bin", + "//tensorflow/lite:testdata/add_quantized.bin", + ], + deps = [ + ":c_api", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + +cc_test( + name = "c_api_experimental_test", + size = "small", + srcs = ["c_api_experimental_test.cc"], + data = ["//tensorflow/lite:testdata/add.bin"], + deps = [ + ":c_api", + ":c_api_experimental", + "//tensorflow/lite:kernel_api", + "//tensorflow/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + cc_library( name = "common", srcs = ["common.c"], @@ -13,6 +133,7 @@ cc_library( visibility = [ "//tensorflow/lite:__subpackages__", ], + alwayslink = 1, ) # For use with library targets that can't use relative paths. diff --git a/tensorflow/lite/c/README.md b/tensorflow/lite/c/README.md new file mode 100644 index 00000000000..06579199393 --- /dev/null +++ b/tensorflow/lite/c/README.md @@ -0,0 +1,48 @@ +# TensorFlow Lite C API + +This directory contains C APIs for TensorFlow Lite. This includes C APIs +for common types, like kernels and delegates, as well as an explicit C API +for inference. + +## Header summary + +Each public C header contains types and methods for specific uses: + +* `common.h` - Contains common C enums, types and methods used throughout + TensorFlow Lite. This includes everything from error codes, to the kernel + and delegate APIs. +* `builtin_op_data.h` - Contains op-specific data that is used for builtin + kernels. This should only be used when (re)implementing a builtin operator. +* `c_api.h` - Contains the TensorFlow Lite C API for inference. The + functionality here is largely equivalent (though a strict subset of) the + functionality provided by the C++ `Interpreter` API. +* `c_api_experimental.h` - Contains experimental C API methods for inference. + These methods are useful and usable, but aren't yet part of the stable API. + +## Using the C API + +See the [`c_api.h`](c_api.h) header for API usage details. + +## Building the C API + +A native shared library target that contains the C API for inference has been +provided. Assuming a working [bazel](https://bazel.build/versions/master/docs/install.html) +configuration, this can be built as follows: + +```sh +bazel build -c opt --cxxopt=--std=c++11 //tensorflow/lite/c:tensorflowlite_c +``` + +and for Android (replace `android_arm` with `android_arm64` for 64-bit), +assuming you've [configured your project for Android builds](../g3doc/guide/android.md): + +```sh +bazel build -c opt --cxxopt=--std=c++11 --config=android_arm \ + //tensorflow/lite/c:tensorflowlite_c +``` + +The generated shared library will be available in your +`bazel-bin/tensorflow/lite/c` directory. A target which packages the shared +library together with the necessary headers (`c_api.h`, `c_api_experimental.h` +and `common.h`) will be available soon, and will also be released as a prebuilt +archive (together with existing prebuilt packages for Android/iOS). diff --git a/tensorflow/lite/experimental/c/c_api.cc b/tensorflow/lite/c/c_api.cc similarity index 97% rename from tensorflow/lite/experimental/c/c_api.cc rename to tensorflow/lite/c/c_api.cc index ab3ee961bb1..7ceddab4ecf 100644 --- a/tensorflow/lite/experimental/c/c_api.cc +++ b/tensorflow/lite/c/c_api.cc @@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/c/c_api.h" +#include "tensorflow/lite/c/c_api.h" #include +#include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/error_reporter.h" -#include "tensorflow/lite/experimental/c/c_api_internal.h" -#include "tensorflow/lite/experimental/c/c_api_types.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" diff --git a/tensorflow/lite/experimental/c/c_api.h b/tensorflow/lite/c/c_api.h similarity index 83% rename from tensorflow/lite/experimental/c/c_api.h rename to tensorflow/lite/c/c_api.h index 09a045b1f2a..036df27b5d1 100644 --- a/tensorflow/lite/experimental/c/c_api.h +++ b/tensorflow/lite/c/c_api.h @@ -12,28 +12,59 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_C_C_API_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_C_C_API_H_ +#ifndef TENSORFLOW_LITE_C_C_API_H_ +#define TENSORFLOW_LITE_C_C_API_H_ #include #include -// Eventually the various C APIs defined in context.h will be migrated into -// the appropriate /c/c_api*.h header. For now, we pull in existing definitions -// for convenience. -#include "c_api_types.h" +#include "common.h" // -------------------------------------------------------------------------- -// Experimental C API for TensorFlowLite. -// -// The API leans towards simplicity and uniformity instead of convenience, as -// most usage will be by language-specific wrappers. -// -// Conventions: -// * We use the prefix TfLite for everything in the API. -// * size_t is used to represent byte sizes of objects that are -// materialized in the address space of the calling process. -// * int is used as an index into arrays. +/// C API for TensorFlow Lite. +/// +/// The API leans towards simplicity and uniformity instead of convenience, as +/// most usage will be by language-specific wrappers. It provides largely the +/// same set of functionality as that of the C++ TensorFlow Lite `Interpreter` +/// API, but is useful for shared libraries where having a stable ABI boundary +/// is important. +/// +/// Conventions: +/// * We use the prefix TfLite for everything in the API. +/// * size_t is used to represent byte sizes of objects that are +/// materialized in the address space of the calling process. +/// * int is used as an index into arrays. +/// +/// Usage: +///

+/// // Create the model and interpreter options.
+/// TfLiteModel* model = TfLiteModelCreateFromFile("/path/to/model.tflite");
+/// TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
+/// TfLiteInterpreterOptionsSetNumThreads(options, 2);
+///
+/// // Create the interpreter.
+/// TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options);
+///
+/// // Allocate tensors and populate the input tensor data.
+/// TfLiteInterpreterAllocateTensors(interpreter);
+/// TfLiteTensor* input_tensor =
+///     TfLiteInterpreterGetInputTensor(interpreter, 0);
+/// TfLiteTensorCopyFromBuffer(input_tensor, input.data(),
+///                            input.size() * sizeof(float));
+///
+/// // Execute inference.
+/// TfLiteInterpreterInvoke(interpreter);
+///
+/// // Extract the output tensor data.
+/// TfLiteTensor* output_tensor =
+//      TfLiteInterpreterGetInputTensor(interpreter, 0);
+/// TfLiteTensorCopyToBuffer(output_tensor, output.data(),
+///                          output.size() * sizeof(float));
+///
+/// // Dispose of the model and interpreter objects.
+/// TfLiteInterpreterDelete(interpreter);
+/// TfLiteInterpreterOptionsDelete(options);
+/// TfLiteModelDelete(model);
 
 #ifdef SWIG
 #define TFL_CAPI_EXPORT
@@ -235,4 +266,4 @@ TFL_CAPI_EXPORT extern TfLiteStatus TfLiteTensorCopyToBuffer(
 }  // extern "C"
 #endif  // __cplusplus
 
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_C_C_API_H_
+#endif  // TENSORFLOW_LITE_C_C_API_H_
diff --git a/tensorflow/lite/experimental/c/c_api_experimental.cc b/tensorflow/lite/c/c_api_experimental.cc
similarity index 93%
rename from tensorflow/lite/experimental/c/c_api_experimental.cc
rename to tensorflow/lite/c/c_api_experimental.cc
index 5bc305ef64b..4b812172937 100644
--- a/tensorflow/lite/experimental/c/c_api_experimental.cc
+++ b/tensorflow/lite/c/c_api_experimental.cc
@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/lite/experimental/c/c_api_experimental.h"
+#include "tensorflow/lite/c/c_api_experimental.h"
 
-#include "tensorflow/lite/experimental/c/c_api_internal.h"
+#include "tensorflow/lite/c/c_api_internal.h"
 
 #ifdef __cplusplus
 extern "C" {
diff --git a/tensorflow/lite/experimental/c/c_api_experimental.h b/tensorflow/lite/c/c_api_experimental.h
similarity index 92%
rename from tensorflow/lite/experimental/c/c_api_experimental.h
rename to tensorflow/lite/c/c_api_experimental.h
index ce1a4a37293..a8f1a4294f5 100644
--- a/tensorflow/lite/experimental/c/c_api_experimental.h
+++ b/tensorflow/lite/c/c_api_experimental.h
@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_C_C_API_EXPERIMENTAL_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_C_C_API_EXPERIMENTAL_H_
+#ifndef TENSORFLOW_LITE_C_C_API_EXPERIMENTAL_H_
+#define TENSORFLOW_LITE_C_C_API_EXPERIMENTAL_H_
 
 #include "tensorflow/lite/builtin_ops.h"
-#include "tensorflow/lite/experimental/c/c_api.h"
+#include "tensorflow/lite/c/c_api.h"
 
 #ifdef __cplusplus
 extern "C" {
diff --git a/tensorflow/lite/experimental/c/c_api_experimental_test.cc b/tensorflow/lite/c/c_api_experimental_test.cc
similarity index 94%
rename from tensorflow/lite/experimental/c/c_api_experimental_test.cc
rename to tensorflow/lite/c/c_api_experimental_test.cc
index 0d383998a29..ce72954774c 100644
--- a/tensorflow/lite/experimental/c/c_api_experimental_test.cc
+++ b/tensorflow/lite/c/c_api_experimental_test.cc
@@ -13,11 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/lite/experimental/c/c_api_experimental.h"
+#include "tensorflow/lite/c/c_api_experimental.h"
 
 #include 
 #include "tensorflow/lite/builtin_ops.h"
-#include "tensorflow/lite/experimental/c/c_api.h"
+#include "tensorflow/lite/c/c_api.h"
 #include "tensorflow/lite/testing/util.h"
 
 namespace {
diff --git a/tensorflow/lite/experimental/c/c_api_internal.h b/tensorflow/lite/c/c_api_internal.h
similarity index 91%
rename from tensorflow/lite/experimental/c/c_api_internal.h
rename to tensorflow/lite/c/c_api_internal.h
index 8f5c301bc1d..474482d159a 100644
--- a/tensorflow/lite/experimental/c/c_api_internal.h
+++ b/tensorflow/lite/c/c_api_internal.h
@@ -12,16 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_C_C_API_INTERNAL_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_C_C_API_INTERNAL_H_
+#ifndef TENSORFLOW_LITE_C_C_API_INTERNAL_H_
+#define TENSORFLOW_LITE_C_C_API_INTERNAL_H_
 
-#include "tensorflow/lite/experimental/c/c_api.h"
+#include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/model.h"
 #include "tensorflow/lite/op_resolver.h"
 
 // Internal structures used by the C API. These are likely to change and should
-// not be depended on.
+// not be depended on directly by any C API clients.
 //
 // NOTE: This header does not follow C conventions and does not define a C API.
 // It is effectively an (internal) implementation detail of the C API.
diff --git a/tensorflow/lite/experimental/c/c_api_test.cc b/tensorflow/lite/c/c_api_test.cc
similarity index 99%
rename from tensorflow/lite/experimental/c/c_api_test.cc
rename to tensorflow/lite/c/c_api_test.cc
index 8de0f414086..eb2a70f9f0b 100644
--- a/tensorflow/lite/experimental/c/c_api_test.cc
+++ b/tensorflow/lite/c/c_api_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/lite/experimental/c/c_api.h"
+#include "tensorflow/lite/c/c_api.h"
 
 #include 
 #include 
diff --git a/tensorflow/lite/experimental/c/exported_symbols.lds b/tensorflow/lite/c/exported_symbols.lds
similarity index 100%
rename from tensorflow/lite/experimental/c/exported_symbols.lds
rename to tensorflow/lite/c/exported_symbols.lds
diff --git a/tensorflow/lite/experimental/c/version_script.lds b/tensorflow/lite/c/version_script.lds
similarity index 100%
rename from tensorflow/lite/experimental/c/version_script.lds
rename to tensorflow/lite/c/version_script.lds
diff --git a/tensorflow/lite/experimental/c/BUILD b/tensorflow/lite/experimental/c/BUILD
deleted file mode 100644
index 8e6b4803155..00000000000
--- a/tensorflow/lite/experimental/c/BUILD
+++ /dev/null
@@ -1,120 +0,0 @@
-load(
-    "//tensorflow/lite:build_def.bzl",
-    "tflite_cc_shared_object",
-    "tflite_copts",
-)
-
-package(
-    default_visibility = [":experimental"],
-    licenses = ["notice"],  # Apache 2.0
-)
-
-package_group(
-    name = "experimental",
-    packages = [
-        "//tensorflow/lite/experimental/...",
-        "//third_party/dart/tflite_native/...",  # whitelisted
-    ],
-)
-
-tflite_cc_shared_object(
-    name = "libtensorflowlite_c.so",
-    linkopts = select({
-        "//tensorflow:macos": [
-            "-Wl,-exported_symbols_list,$(location //tensorflow/lite/experimental/c:exported_symbols.lds)",
-            "-Wl,-install_name,@rpath/libtensorflowlite_c.so",
-        ],
-        "//tensorflow:windows": [],
-        "//conditions:default": [
-            "-z defs",
-            "-Wl,--version-script,$(location //tensorflow/lite/experimental/c:version_script.lds)",
-        ],
-    }),
-    deps = [
-        ":c_api",
-        ":c_api_experimental",
-        ":exported_symbols.lds",
-        ":version_script.lds",
-    ],
-)
-
-cc_library(
-    name = "c_api_internal",
-    srcs = [
-        "c_api.h",
-        "c_api_types.h",
-    ],
-    hdrs = ["c_api_internal.h"],
-    copts = tflite_copts(),
-    visibility = [
-        "//tensorflow/lite/experimental/c:__subpackages__",
-    ],
-    deps = [
-        "//tensorflow/lite:framework",
-        "//tensorflow/lite/c:common",
-    ],
-)
-
-cc_library(
-    name = "c_api",
-    srcs = ["c_api.cc"],
-    hdrs = [
-        "c_api.h",
-        "c_api_types.h",
-    ],
-    copts = tflite_copts(),
-    visibility = [
-        ":experimental",
-    ],
-    deps = [
-        ":c_api_internal",
-        "//tensorflow/lite:framework",
-        "//tensorflow/lite:version",
-        "//tensorflow/lite/c:common",
-        "//tensorflow/lite/kernels:builtin_ops",
-    ],
-    alwayslink = 1,
-)
-
-cc_library(
-    name = "c_api_experimental",
-    srcs = ["c_api_experimental.cc"],
-    hdrs = ["c_api_experimental.h"],
-    copts = tflite_copts(),
-    deps = [
-        ":c_api",
-        ":c_api_internal",
-        "//tensorflow/lite:kernel_api",
-    ],
-    alwayslink = 1,
-)
-
-cc_test(
-    name = "c_api_test",
-    size = "small",
-    srcs = ["c_api_test.cc"],
-    data = [
-        "//tensorflow/lite:testdata/add.bin",
-        "//tensorflow/lite:testdata/add_quantized.bin",
-    ],
-    deps = [
-        ":c_api",
-        "//tensorflow/lite/c:common",
-        "//tensorflow/lite/testing:util",
-        "@com_google_googletest//:gtest",
-    ],
-)
-
-cc_test(
-    name = "c_api_experimental_test",
-    size = "small",
-    srcs = ["c_api_experimental_test.cc"],
-    data = ["//tensorflow/lite:testdata/add.bin"],
-    deps = [
-        ":c_api",
-        ":c_api_experimental",
-        "//tensorflow/lite:kernel_api",
-        "//tensorflow/lite/testing:util",
-        "@com_google_googletest//:gtest",
-    ],
-)
diff --git a/tensorflow/lite/experimental/c/README.md b/tensorflow/lite/experimental/c/README.md
new file mode 100644
index 00000000000..a17f7f8f2c7
--- /dev/null
+++ b/tensorflow/lite/experimental/c/README.md
@@ -0,0 +1 @@
+The C API has been migrated to [lite/c](../../c/README.md).
diff --git a/tensorflow/lite/experimental/c/c_api_types.h b/tensorflow/lite/experimental/c/c_api_types.h
deleted file mode 100644
index b3b0ddc059d..00000000000
--- a/tensorflow/lite/experimental/c/c_api_types.h
+++ /dev/null
@@ -1,673 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// This file defines common C types and APIs for implementing operations,
-// delegates and other constructs in TensorFlow Lite. The actual operations and
-// delegtes can be defined using C++, but the interface between the interpreter
-// and the operations are C.
-//
-// Summary of abstractions
-// TF_LITE_ENSURE - Self-sufficient error checking
-// TfLiteStatus - Status reporting
-// TfLiteIntArray - stores tensor shapes (dims),
-// TfLiteContext - allows an op to access the tensors
-// TfLiteTensor - tensor (a multidimensional array)
-// TfLiteNode - a single node or operation
-// TfLiteRegistration - the implementation of a conceptual operation.
-// TfLiteDelegate - allows delegation of nodes to alternative backends.
-//
-// Some abstractions in this file are created and managed by Interpreter.
-
-#ifndef TENSORFLOW_LITE_C_COMMON_H_
-#define TENSORFLOW_LITE_C_COMMON_H_
-
-#include 
-#include 
-#include 
-
-#ifdef __cplusplus
-extern "C" {
-#endif  // __cplusplus
-
-typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus;
-
-// The list of external context types known to TF Lite. This list exists solely
-// to avoid conflicts and to ensure ops can share the external contexts they
-// need. Access to the external contexts is controled by one of the
-// corresponding support files.
-typedef enum {
-  kTfLiteEigenContext = 0,       // include eigen_support.h to use.
-  kTfLiteGemmLowpContext = 1,    // include gemm_support.h to use.
-  kTfLiteEdgeTpuContext = 2,     // Placeholder for Edge TPU support.
-  kTfLiteCpuBackendContext = 3,  // include cpu_backend_support.h to use.
-  kTfLiteMaxExternalContexts = 4
-} TfLiteExternalContextType;
-
-// Forward declare so dependent structs and methods can reference these types
-// prior to the struct definitions.
-struct TfLiteContext;
-struct TfLiteDelegate;
-struct TfLiteRegistration;
-
-// An external context is a collection of information unrelated to the TF Lite
-// framework, but useful to a subset of the ops. TF Lite knows very little
-// about about the actual contexts, but it keeps a list of them, and is able to
-// refresh them if configurations like the number of recommended threads
-// change.
-typedef struct {
-  TfLiteExternalContextType type;
-  TfLiteStatus (*Refresh)(struct TfLiteContext* context);
-} TfLiteExternalContext;
-
-#define kTfLiteOptionalTensor (-1)
-
-// Fixed size list of integers. Used for dimensions and inputs/outputs tensor
-// indices
-typedef struct {
-  int size;
-// gcc 6.1+ have a bug where flexible members aren't properly handled
-// https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c
-#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \
-    __GNUC_MINOR__ >= 1
-  int data[0];
-#else
-  int data[];
-#endif
-} TfLiteIntArray;
-
-// Given the size (number of elements) in a TfLiteIntArray, calculate its size
-// in bytes.
-int TfLiteIntArrayGetSizeInBytes(int size);
-
-// Create a array of a given `size` (uninitialized entries).
-// This returns a pointer, that you must free using TfLiteIntArrayFree().
-TfLiteIntArray* TfLiteIntArrayCreate(int size);
-
-// Check if two intarrays are equal. Returns 1 if they are equal, 0 otherwise.
-int TfLiteIntArrayEqual(const TfLiteIntArray* a, const TfLiteIntArray* b);
-
-// Check if an intarray equals an array. Returns 1 if equals, 0 otherwise.
-int TfLiteIntArrayEqualsArray(const TfLiteIntArray* a, int b_size,
-                              const int b_data[]);
-
-// Create a copy of an array passed as `src`.
-// You are expected to free memory with TfLiteIntArrayFree
-TfLiteIntArray* TfLiteIntArrayCopy(const TfLiteIntArray* src);
-
-// Free memory of array `a`.
-void TfLiteIntArrayFree(TfLiteIntArray* a);
-
-// Fixed size list of floats. Used for per-channel quantization.
-typedef struct {
-  int size;
-// gcc 6.1+ have a bug where flexible members aren't properly handled
-// https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c
-#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \
-    __GNUC_MINOR__ >= 1
-  float data[0];
-#else
-  float data[];
-#endif
-} TfLiteFloatArray;
-
-// Given the size (number of elements) in a TfLiteFloatArray, calculate its size
-// in bytes.
-int TfLiteFloatArrayGetSizeInBytes(int size);
-
-// Create a array of a given `size` (uninitialized entries).
-// This returns a pointer, that you must free using TfLiteFloatArrayFree().
-TfLiteFloatArray* TfLiteFloatArrayCreate(int size);
-
-// Free memory of array `a`.
-void TfLiteFloatArrayFree(TfLiteFloatArray* a);
-
-// Since we must not depend on any libraries, define a minimal subset of
-// error macros while avoiding names that have pre-conceived meanings like
-// assert and check.
-
-// Check whether value is true, and if not return kTfLiteError from
-// the current function (and report the error string msg).
-#define TF_LITE_ENSURE_MSG(context, value, msg)            \
-  do {                                                     \
-    if (!(value)) {                                        \
-      (context)->ReportError((context), __FILE__ " " msg); \
-      return kTfLiteError;                                 \
-    }                                                      \
-  } while (0)
-
-// Check whether the value `a` is true, and if not return kTfLiteError from
-// the current function, while also reporting the location of the error.
-#define TF_LITE_ENSURE(context, a)                                          \
-  do {                                                                      \
-    if (!(a)) {                                                             \
-      (context)->ReportError((context), "%s:%d %s was not true.", __FILE__, \
-                             __LINE__, #a);                                 \
-      return kTfLiteError;                                                  \
-    }                                                                       \
-  } while (0)
-
-#define TF_LITE_ENSURE_STATUS(a) \
-  do {                           \
-    if ((a) != kTfLiteOk) {      \
-      return kTfLiteError;       \
-    }                            \
-  } while (0)
-
-// Check whether the value `a == b` is true, and if not return kTfLiteError from
-// the current function, while also reporting the location of the error.
-// `a` and `b` may be evaluated more than once, so no side effects or
-// extremely expensive computations should be done.
-#define TF_LITE_ENSURE_EQ(context, a, b)                                       \
-  do {                                                                         \
-    if ((a) != (b)) {                                                          \
-      (context)->ReportError((context), "%s:%d %s != %s (%d != %d)", __FILE__, \
-                             __LINE__, #a, #b, (a), (b));                      \
-      return kTfLiteError;                                                     \
-    }                                                                          \
-  } while (0)
-
-#define TF_LITE_ENSURE_TYPES_EQ(context, a, b)                                 \
-  do {                                                                         \
-    if ((a) != (b)) {                                                          \
-      (context)->ReportError((context), "%s:%d %s != %s (%s != %s)", __FILE__, \
-                             __LINE__, #a, #b, TfLiteTypeGetName(a),           \
-                             TfLiteTypeGetName(b));                            \
-      return kTfLiteError;                                                     \
-    }                                                                          \
-  } while (0)
-
-#define TF_LITE_ENSURE_OK(context, status) \
-  do {                                     \
-    if ((status) != kTfLiteOk) {           \
-      return kTfLiteError;                 \
-    }                                      \
-  } while (0)
-
-// Single-precision complex data type compatible with the C99 definition.
-typedef struct {
-  float re, im;  // real and imaginary parts, respectively.
-} TfLiteComplex64;
-
-// Half precision data type compatible with the C99 definition.
-typedef struct {
-  uint16_t data;
-} TfLiteFloat16;
-
-// Types supported by tensor
-typedef enum {
-  kTfLiteNoType = 0,
-  kTfLiteFloat32 = 1,
-  kTfLiteInt32 = 2,
-  kTfLiteUInt8 = 3,
-  kTfLiteInt64 = 4,
-  kTfLiteString = 5,
-  kTfLiteBool = 6,
-  kTfLiteInt16 = 7,
-  kTfLiteComplex64 = 8,
-  kTfLiteInt8 = 9,
-  kTfLiteFloat16 = 10,
-} TfLiteType;
-
-// Return the name of a given type, for error reporting purposes.
-const char* TfLiteTypeGetName(TfLiteType type);
-
-// SupportedQuantizationTypes.
-typedef enum {
-  // No quantization.
-  kTfLiteNoQuantization = 0,
-  // Affine quantization (with support for per-channel quantization).
-  // Corresponds to TfLiteAffineQuantization.
-  kTfLiteAffineQuantization = 1,
-} TfLiteQuantizationType;
-
-// Structure specifying the quantization used by the tensor, if-any.
-typedef struct {
-  // The type of quantization held by params.
-  TfLiteQuantizationType type;
-  // Holds a reference to one of the quantization param structures specified
-  // below.
-  void* params;
-} TfLiteQuantization;
-
-// Legacy. Will be deprecated in favor of TfLiteAffineQuantization.
-// If per-layer quantization is specified this field will still be populated in
-// addition to TfLiteAffineQuantization.
-// Parameters for asymmetric quantization. Quantized values can be converted
-// back to float using:
-//     real_value = scale * (quantized_value - zero_point)
-typedef struct {
-  float scale;
-  int32_t zero_point;
-} TfLiteQuantizationParams;
-
-// Parameters for asymmetric quantization across a dimension (i.e per output
-// channel quantization).
-// quantized_dimension specifies which dimension the scales and zero_points
-// correspond to.
-// For a particular value in quantized_dimension, quantized values can be
-// converted back to float using:
-//     real_value = scale * (quantized_value - zero_point)
-typedef struct {
-  TfLiteFloatArray* scale;
-  TfLiteIntArray* zero_point;
-  int32_t quantized_dimension;
-} TfLiteAffineQuantization;
-
-/* A union of pointers that points to memory for a given tensor. */
-typedef union {
-  /* Do not access these members directly, if possible, use
-   * GetTensorData(tensor) instead, otherwise only access .data, as other
-   * members are deprecated. */
-  int32_t* i32;
-  int64_t* i64;
-  float* f;
-  TfLiteFloat16* f16;
-  char* raw;
-  const char* raw_const;
-  uint8_t* uint8;
-  bool* b;
-  int16_t* i16;
-  TfLiteComplex64* c64;
-  int8_t* int8;
-  /* Only use this member. */
-  void* data;
-} TfLitePtrUnion;
-
-// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped
-// data (or data externally allocated). kTfLiteArenaRw is arena allocated
-// data. kTfLiteDynamic is for tensors that are allocated during evaluation.
-typedef enum {
-  kTfLiteMemNone = 0,
-  kTfLiteMmapRo,
-  kTfLiteArenaRw,
-  kTfLiteArenaRwPersistent,
-  kTfLiteDynamic,
-} TfLiteAllocationType;
-
-// The delegates should use zero or positive integers to represent handles.
-// -1 is reserved from unallocated status.
-typedef int TfLiteBufferHandle;
-enum {
-  kTfLiteNullBufferHandle = -1,
-};
-
-// An tensor in the interpreter system which is a wrapper around a buffer of
-// data including a dimensionality (or NULL if not currently defined).
-typedef struct {
-  // The data type specification for data stored in `data`. This affects
-  // what member of `data` union should be used.
-  TfLiteType type;
-  // A union of data pointers. The appropriate type should be used for a typed
-  // tensor based on `type`.
-  TfLitePtrUnion data;
-  // A pointer to a structure representing the dimensionality interpretation
-  // that the buffer should have. NOTE: the product of elements of `dims`
-  // and the element datatype size should be equal to `bytes` below.
-  TfLiteIntArray* dims;
-  // Quantization information.
-  TfLiteQuantizationParams params;
-  // How memory is mapped
-  //  kTfLiteMmapRo: Memory mapped read only.
-  //  i.e. weights
-  //  kTfLiteArenaRw: Arena allocated read write memory
-  //  (i.e. temporaries, outputs).
-  TfLiteAllocationType allocation_type;
-  // The number of bytes required to store the data of this Tensor. I.e.
-  // (bytes of each element) * dims[0] * ... * dims[n-1].  For example, if
-  // type is kTfLiteFloat32 and dims = {3, 2} then
-  // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
-  size_t bytes;
-
-  // An opaque pointer to a tflite::MMapAllocation
-  const void* allocation;
-
-  // Null-terminated name of this tensor.
-  const char* name;
-
-  // The delegate which knows how to handle `buffer_handle`.
-  // WARNING: This is an experimental interface that is subject to change.
-  struct TfLiteDelegate* delegate;
-
-  // An integer buffer handle that can be handled by `delegate`.
-  // The value is valid only when delegate is not null.
-  // WARNING: This is an experimental interface that is subject to change.
-  TfLiteBufferHandle buffer_handle;
-
-  // If the delegate uses its own buffer (e.g. GPU memory), the delegate is
-  // responsible to set data_is_stale to true.
-  // `delegate->CopyFromBufferHandle` can be called to copy the data from
-  // delegate buffer.
-  // WARNING: This is an // experimental interface that is subject to change.
-  bool data_is_stale;
-
-  // True if the tensor is a variable.
-  bool is_variable;
-
-  // Quantization information. Replaces params field above.
-  TfLiteQuantization quantization;
-} TfLiteTensor;
-
-// Free data memory of tensor `t`.
-void TfLiteTensorDataFree(TfLiteTensor* t);
-
-// Free quantization data.
-void TfLiteQuantizationFree(TfLiteQuantization* quantization);
-
-// Free memory of tensor `t`.
-void TfLiteTensorFree(TfLiteTensor* t);
-
-// Set all of a tensor's fields (and free any previously allocated data).
-void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
-                       TfLiteQuantizationParams quantization, char* buffer,
-                       size_t size, TfLiteAllocationType allocation_type,
-                       const void* allocation, bool is_variable,
-                       TfLiteTensor* tensor);
-
-// Resize the allocated data of a (dynamic) tensor. Tensors with allocation
-// types other than kTfLiteDynamic will be ignored.
-void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
-
-// A structure representing an instance of a node.
-// This structure only exhibits the inputs, outputs and user defined data, not
-// other features like the type.
-typedef struct {
-  // Inputs to this node expressed as indices into the simulator's tensors.
-  TfLiteIntArray* inputs;
-
-  // Outputs to this node expressed as indices into the simulator's tensors.
-  TfLiteIntArray* outputs;
-
-  // intermediate tensors to this node expressed as indices into the simulator's
-  // tensors.
-  TfLiteIntArray* intermediates;
-
-  // Temporary tensors uses during the computations. This usually contains no
-  // tensors, but ops are allowed to change that if they need scratch space of
-  // any sort.
-  TfLiteIntArray* temporaries;
-
-  // Opaque data provided by the node implementer through `Registration.init`.
-  void* user_data;
-
-  // Opaque data provided to the node if the node is a builtin. This is usually
-  // a structure defined in builtin_op_data.h
-  void* builtin_data;
-
-  // Custom initial data. This is the opaque data provided in the flatbuffer.
-  // WARNING: This is an experimental interface that is subject to change.
-  const void* custom_initial_data;
-  int custom_initial_data_size;
-
-  // The pointer to the delegate. This is non-null only when the node is
-  // created by calling `interpreter.ModifyGraphWithDelegate`.
-  // WARNING: This is an experimental interface that is subject to change.
-  struct TfLiteDelegate* delegate;
-} TfLiteNode;
-
-typedef struct TfLiteContext {
-  // Number of tensors in the context.
-  size_t tensors_size;
-
-  // The execution plan contains a list of the node indices in execution
-  // order. execution_plan->size is the current number of nodes. And,
-  // execution_plan->data[0] is the first node that needs to be run.
-  // TfLiteDelegates can traverse the current execution plan by iterating
-  // through each member of this array and using GetNodeAndRegistration() to
-  // access details about a node. i.e.
-  // TfLiteIntArray* execution_plan;
-  // TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &execution_plan));
-  // for (int exec_index = 0; exec_index < execution_plan->size; exec_index++) {
-  //    int node_index = execution_plan->data[exec_index];
-  //    TfLiteNode* node;
-  //    TfLiteRegistration* reg;
-  //    context->GetNodeAndRegistration(context, node_index, &node, ®);
-  // }
-  // WARNING: This is an experimental interface that is subject to change.
-  TfLiteStatus (*GetExecutionPlan)(struct TfLiteContext* context,
-                                   TfLiteIntArray** execution_plan);
-
-  // An array of tensors in the interpreter context (of length `tensors_size`)
-  TfLiteTensor* tensors;
-
-  // opaque full context ptr (an opaque c++ data structure)
-  void* impl_;
-
-  // Request memory pointer be resized. Updates dimensions on the tensor.
-  // NOTE: ResizeTensor takes ownership of newSize.
-  TfLiteStatus (*ResizeTensor)(struct TfLiteContext*, TfLiteTensor* tensor,
-                               TfLiteIntArray* new_size);
-  // Request that an error be reported with format string msg.
-  void (*ReportError)(struct TfLiteContext*, const char* msg, ...);
-
-  // Add `tensors_to_add` tensors, preserving pre-existing Tensor entries.  If
-  // non-null, the value pointed to by `first_new_tensor_index` will be set to
-  // the index of the first new tensor.
-  TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add,
-                             int* first_new_tensor_index);
-
-  // Get a Tensor node by node_index.
-  // WARNING: This is an experimental interface that is subject to change.
-  TfLiteStatus (*GetNodeAndRegistration)(
-      struct TfLiteContext*, int node_index, TfLiteNode** node,
-      struct TfLiteRegistration** registration);
-
-  // Replace ops with one or more stub delegate operations. This function
-  // does not take ownership of `nodes_to_replace`.
-  TfLiteStatus (*ReplaceNodeSubsetsWithDelegateKernels)(
-      struct TfLiteContext*, struct TfLiteRegistration registration,
-      const TfLiteIntArray* nodes_to_replace, struct TfLiteDelegate* delegate);
-
-  // Number of threads that are recommended to subsystems like gemmlowp and
-  // eigen.
-  int recommended_num_threads;
-
-  // Access external contexts by type.
-  // WARNING: This is an experimental interface that is subject to change.
-  TfLiteExternalContext* (*GetExternalContext)(struct TfLiteContext*,
-                                               TfLiteExternalContextType);
-  // Set the value of a external context. Does not take ownership of the
-  // pointer.
-  // WARNING: This is an experimental interface that is subject to change.
-  void (*SetExternalContext)(struct TfLiteContext*, TfLiteExternalContextType,
-                             TfLiteExternalContext*);
-
-  // Flag for allowing float16 precision for FP32 calculation.
-  // default: false.
-  // WARNING: This is an experimental API and subject to change.
-  bool allow_fp32_relax_to_fp16;
-
-  // Pointer to the op-level profiler, if set; nullptr otherwise.
-  void* profiler;
-
-  // Allocate memory for op data. This method should only be used in `Init`
-  // method and the allocated memory will be available until `Free` method is
-  // called.
-  // On TFL, it allocates memory from heap using malloc, but for micro, this
-  // will be allocating from the allocator.
-  // WARNING: This is an experimental interface that is subject to change.
-  void* (*AllocateOpData)(struct TfLiteContext* ctx, size_t size);
-
-  // Deallocate memory holding op data. This method should only be used inside
-  // `Free` method. Caller needs to make sure that that `buffer` is allocated by
-  // `AllocateOpData` method.
-  // On TFL, it will free the buffer, and for micro, this method is a no-op.
-  // WARNING: This is an experimental interface that is subject to change.
-  void (*DeallocateOpData)(struct TfLiteContext* ctx, void* buffer);
-
-  // Allocate a temporary tensor to the node. This method also makes a copy of
-  // the shape array internally so the shape array could be deallocated right
-  // afterwards. WARNING: This is an experimental interface that is subject to
-  // change.
-  TfLiteStatus (*AllocateTemporaryTensor)(struct TfLiteContext* ctx,
-                                          TfLiteNode* node, int dims,
-                                          int* shape, TfLiteType data_type,
-                                          TfLiteAllocationType allocation_type,
-                                          int* new_tensor_index);
-
-  // Deallocate all temporary tensors associated to the node (including
-  // kTfLiteArenaRwPersistent persistent tensors). It also deallocates
-  // all the shape tensors.
-  // WARNING: This is an experimental interface that is subject to change.
-  void (*DeallocateAllTemporaryTensors)(struct TfLiteContext* ctx,
-                                        TfLiteNode* node);
-
-  // Resize the memory pointer of the `tensor`. This method behaves the same as
-  // `ResizeTensor`, except that it makes a copy of the shape array internally
-  // so the shape array could be deallocated right afterwards.
-  // WARNING: This is an experimental interface that is subject to change.
-  TfLiteStatus (*ResizeTensorExplicit)(struct TfLiteContext* ctx,
-                                       TfLiteTensor* tensor, int dims,
-                                       const int* shape);
-} TfLiteContext;
-
-typedef struct TfLiteRegistration {
-  // Initializes the op from serialized data.
-  // If a built-in op:
-  //   `buffer` is the op's params data (TfLiteLSTMParams*).
-  //   `length` is zero.
-  // If custom op:
-  //   `buffer` is the op's `custom_options`.
-  //   `length` is the size of the buffer.
-  //
-  // Returns a type-punned (i.e. void*) opaque data (e.g. a primitive pointer
-  // or an instance of a struct).
-  //
-  // The returned pointer will be stored with the node in the `user_data` field,
-  // accessible within prepare and invoke functions below.
-  // NOTE: if the data is already in the desired format, simply implement this
-  // function to return `nullptr` and implement the free function to be a no-op.
-  void* (*init)(TfLiteContext* context, const char* buffer, size_t length);
-
-  // The pointer `buffer` is the data previously returned by an init invocation.
-  void (*free)(TfLiteContext* context, void* buffer);
-
-  // prepare is called when the inputs this node depends on have been resized.
-  // context->ResizeTensor() can be called to request output tensors to be
-  // resized.
-  //
-  // Returns kTfLiteOk on success.
-  TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node);
-
-  // Execute the node (should read node->inputs and output to node->outputs).
-  // Returns kTfLiteOk on success.
-  TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);
-
-  // profiling_string is called during summarization of profiling information
-  // in order to group executions together. Providing a value here will cause a
-  // given op to appear multiple times is the profiling report. This is
-  // particularly useful for custom ops that can perform significantly
-  // different calculations depending on their `user-data`.
-  const char* (*profiling_string)(const TfLiteContext* context,
-                                  const TfLiteNode* node);
-
-  // Builtin codes. If this kernel refers to a builtin this is the code
-  // of the builtin. This is so we can do marshaling to other frameworks like
-  // NN API.
-  // Note: It is the responsibility of the registration binder to set this
-  // properly.
-  int32_t builtin_code;
-
-  // Custom op name. If the op is a builtin, this will be null.
-  // Note: It is the responsibility of the registration binder to set this
-  // properly.
-  // WARNING: This is an experimental interface that is subject to change.
-  const char* custom_name;
-
-  // The version of the op.
-  // Note: It is the responsibility of the registration binder to set this
-  // properly.
-  int version;
-} TfLiteRegistration;
-
-// The flags used in `TfLiteDelegate`. Note that this is a bitmask, so the
-// values should be 1, 2, 4, 8, ...etc.
-typedef enum {
-  kTfLiteDelegateFlagsNone = 0,
-  // The flag is set if the delegate can handle dynamic sized tensors.
-  // For example, the output shape of a `Resize` op with non-constant shape
-  // can only be inferred when the op is invoked.
-  // In this case, the Delegate is responsible for calling
-  // `SetTensorToDynamic` to mark the tensor as a dynamic tensor, and calling
-  // `ResizeTensor` when invoking the op.
-  //
-  // If the delegate isn't capable to handle dynamic tensors, this flag need
-  // to be set to false.
-  kTfLiteDelegateFlagsAllowDynamicTensors = 1
-} TfLiteDelegateFlags;
-
-// WARNING: This is an experimental interface that is subject to change.
-typedef struct TfLiteDelegate {
-  // Data that delegate needs to identify itself. This data is owned by the
-  // delegate. The delegate is owned in the user code, so the delegate is
-  // responsible for doing this when it is destroyed.
-  void* data_;
-
-  // Invoked by ModifyGraphWithDelegate. This prepare is called, giving the
-  // delegate a view of the current graph through TfLiteContext*. It typically
-  // will look at the nodes and call ReplaceNodeSubsetsWithDelegateKernels()
-  // to ask the TensorFlow lite runtime to create macro-nodes to represent
-  // delegated subgraphs of the original graph.
-  TfLiteStatus (*Prepare)(TfLiteContext* context,
-                          struct TfLiteDelegate* delegate);
-
-  // Copy the data from delegate buffer handle into raw memory of the given
-  // 'tensor'. This cannot be null. The delegate is allowed to allocate the raw
-  // bytes as long as it follows the rules for kTfLiteDynamic tensors.
-  TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context,
-                                       struct TfLiteDelegate* delegate,
-                                       TfLiteBufferHandle buffer_handle,
-                                       TfLiteTensor* tensor);
-
-  // Copy the data from raw memory of the given 'tensor' to delegate buffer
-  // handle. This can be null if the delegate doesn't use its own buffer.
-  TfLiteStatus (*CopyToBufferHandle)(TfLiteContext* context,
-                                     struct TfLiteDelegate* delegate,
-                                     TfLiteBufferHandle buffer_handle,
-                                     TfLiteTensor* tensor);
-
-  // Free the Delegate Buffer Handle. Note: This only frees the handle, but
-  // this doesn't release the underlying resource (e.g. textures). The
-  // resources are either owned by application layer or the delegate.
-  // This can be null if the delegate doesn't use its own buffer.
-  void (*FreeBufferHandle)(TfLiteContext* context,
-                           struct TfLiteDelegate* delegate,
-                           TfLiteBufferHandle* handle);
-
-  // Bitmask flags. See the comments in `TfLiteDelegateFlags`.
-  int64_t flags;
-} TfLiteDelegate;
-
-// Build a 'null' delegate, with all the fields properly set to their default
-// values.
-TfLiteDelegate TfLiteDelegateCreate();
-
-// WARNING: This is an experimental interface that is subject to change.
-//
-// Currently, TfLiteDelegateParams has to be allocated in a way that it's
-// trivially destructable. It will be stored as `builtin_data` field in
-// `TfLiteNode` of the delegate node.
-//
-// See also the `CreateDelegateParams` function in `interpreter.cc` details.
-typedef struct {
-  TfLiteDelegate* delegate;
-  TfLiteIntArray* nodes_to_replace;
-  TfLiteIntArray* input_tensors;
-  TfLiteIntArray* output_tensors;
-} TfLiteDelegateParams;
-
-#ifdef __cplusplus
-}  // extern "C"
-#endif  // __cplusplus
-#endif  // TENSORFLOW_LITE_C_COMMON_H_
diff --git a/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md
index 6daca3e4f5c..cbd1d016b83 100644
--- a/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md
+++ b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md
@@ -6,21 +6,18 @@ Unity by way of a C# `Interpreter` wrapper.
 
 Note that the native TF Lite plugin(s) *must* be built before using the Unity
 Plugin, and placed in Assets/TensorFlowLite/SDK/Plugins/. For the editor (note
-that this has only been tested on Linux; the syntax may differ on Mac/Windows):
+that the generated shared library name and suffix are platform-dependent):
 
 ```sh
-bazel build -c opt --cxxopt=--std=c++11 \
-  //tensorflow/lite/experimental/c:libtensorflowlite_c.so
+bazel build -c opt --cxxopt=--std=c++11 //tensorflow/lite/c:tensorflowlite_c
 ```
 
 and for Android (replace `android_arm` with `android_arm64` for 64-bit):
 
 ```sh
 bazel build -c opt --cxxopt=--std=c++11 --config=android_arm \
-  //tensorflow/lite/experimental/c:libtensorflowlite_c.so
+  //tensorflow/lite/c:tensorflowlite_c
 ```
 
 If you encounter issues with native plugin discovery on Mac ("Darwin")
-platforms, try renaming `libtensorflowlite_c.so` to `tensorflowlite_c.bundle`.
-Similarly, on Windows you'll likely need to rename `libtensorflowlite_c.so` to
-`tensorflowlite_c.dll`.
+platforms, try renaming `libtensorflowlite_c.dylib` to `tensorflowlite_c.bundle`.
diff --git a/tensorflow/lite/experimental/ios/BUILD.apple b/tensorflow/lite/experimental/ios/BUILD.apple
index 2ccb207f19b..6ecd3d589ea 100644
--- a/tensorflow/lite/experimental/ios/BUILD.apple
+++ b/tensorflow/lite/experimental/ios/BUILD.apple
@@ -5,23 +5,20 @@ load("//tensorflow/lite/experimental/ios:ios.bzl", "TFL_MINIMUM_OS_VERSION")
 load("@build_bazel_rules_apple//apple:ios.bzl", "ios_static_framework")
 
 package(
-    default_visibility = ["//tensorflow/lite/experimental/c:experimental"],
+    default_visibility = ["//tensorflow/lite/c:experimental"],
     licenses = ["notice"],  # Apache 2.0
 )
 
 TFL_LIBRARY_HDRS = [
     "//tensorflow/lite/delegates/gpu:metal_delegate.h",
-    "//tensorflow/lite/experimental/c:c_api.h",
-]
-
-TFL_FRAMEWORK_HDRS = TFL_LIBRARY_HDRS + [
-    "//tensorflow/lite/experimental/c:c_api_types.h",
+    "//tensorflow/lite/c:c_api.h",
+    "//tensorflow/lite/c:common.h",
 ]
 
 # bazel build -c opt --config=ios_fat //tensorflow/lite/experimental/ios:TensorFlowLiteC_framework
 ios_static_framework(
     name = "TensorFlowLiteC_framework",
-    hdrs = TFL_FRAMEWORK_HDRS,
+    hdrs = TFL_LIBRARY_HDRS,
     bundle_name = "TensorFlowLiteC",
     minimum_os_version = TFL_MINIMUM_OS_VERSION,
     deps = [
@@ -32,7 +29,7 @@ ios_static_framework(
 # bazel build -c opt --config=ios --ios_multi_cpus=armv7,arm64,x86_64 //tensorflow/lite/experimental/ios:TensorFlowLiteCWithSelectTfOps_framework
 ios_static_framework(
     name = "TensorFlowLiteCWithSelectTfOps_framework",
-    hdrs = TFL_FRAMEWORK_HDRS,
+    hdrs = TFL_LIBRARY_HDRS,
     bundle_name = "TensorFlowLiteC",
     minimum_os_version = TFL_MINIMUM_OS_VERSION,
     deps = [
@@ -68,8 +65,8 @@ cc_library(
     hdrs = TFL_LIBRARY_HDRS,
     tags = ["nobuilder"],
     deps = [
+        "//tensorflow/lite/c:c_api",
         "//tensorflow/lite/delegates/gpu:metal_delegate",
-        "//tensorflow/lite/experimental/c:c_api",
     ],
 )
 
diff --git a/tensorflow/lite/experimental/objc/BUILD.apple b/tensorflow/lite/experimental/objc/BUILD.apple
index 09e672ceff3..198e90c1cbc 100644
--- a/tensorflow/lite/experimental/objc/BUILD.apple
+++ b/tensorflow/lite/experimental/objc/BUILD.apple
@@ -44,7 +44,7 @@ RELEASE_COPTS = [
     # Warns if an @selector() expression is encountered with a method name that hasn't been defined yet.
     "-Wundeclared-selector",
     # Turn off warnings for headers not part of TensorFlow Lite Objective-C API.
-    "--system-header-prefix=tensorflow/lite/experimental/c/",
+    "--system-header-prefix=tensorflow/lite/c/",
 ]
 
 # Compiler flags for building test libraries.
@@ -63,7 +63,7 @@ objc_library(
     tags = TFL_DEFAULT_TAGS,
     visibility = ios_visibility_whitelist(),
     deps = [
-        "//tensorflow/lite/experimental/c:c_api",
+        "//tensorflow/lite/c:c_api",
     ],
     alwayslink = 1,
 )
diff --git a/tensorflow/lite/experimental/objc/TensorFlowLite.tulsiproj/Configs/TensorFlowLite.tulsigen b/tensorflow/lite/experimental/objc/TensorFlowLite.tulsiproj/Configs/TensorFlowLite.tulsigen
index feacdbad8de..bbd35902dce 100644
--- a/tensorflow/lite/experimental/objc/TensorFlowLite.tulsiproj/Configs/TensorFlowLite.tulsigen
+++ b/tensorflow/lite/experimental/objc/TensorFlowLite.tulsiproj/Configs/TensorFlowLite.tulsigen
@@ -1,7 +1,7 @@
 {
   "sourceFilters" : [
     "tensorflow/lite",
-    "tensorflow/lite/experimental/c",
+    "tensorflow/lite/c",
     "tensorflow/lite/experimental/objc",
     "tensorflow/lite/experimental/objc/apis",
     "tensorflow/lite/experimental/objc/apps/TestApp/TestApp",
diff --git a/tensorflow/lite/experimental/objc/TensorFlowLiteObjC-nightly.podspec b/tensorflow/lite/experimental/objc/TensorFlowLiteObjC-nightly.podspec
index 762ba7b83c9..2447f432664 100644
--- a/tensorflow/lite/experimental/objc/TensorFlowLiteObjC-nightly.podspec
+++ b/tensorflow/lite/experimental/objc/TensorFlowLiteObjC-nightly.podspec
@@ -25,7 +25,7 @@ Pod::Spec.new do |s|
   s.source_files = [
     objc_dir + '{apis,sources}/*.{h,m,mm}',
     tfl_dir + 'experimental/c/c_api.h',
-    tfl_dir + 'experimental/c/c_api_types.h',
+    tfl_dir + 'experimental/c/common.h',
   ]
   s.module_map = objc_dir + 'apis/framework.modulemap'
   s.dependency 'TensorFlowLiteC', "~> #{s.version}"
diff --git a/tensorflow/lite/experimental/objc/TensorFlowLiteObjC.podspec b/tensorflow/lite/experimental/objc/TensorFlowLiteObjC.podspec
index 3af0eff111e..b3ece575fd8 100644
--- a/tensorflow/lite/experimental/objc/TensorFlowLiteObjC.podspec
+++ b/tensorflow/lite/experimental/objc/TensorFlowLiteObjC.podspec
@@ -25,7 +25,7 @@ Pod::Spec.new do |s|
   s.source_files = [
     objc_dir + '{apis,sources}/*.{h,m,mm}',
     tfl_dir + 'experimental/c/c_api.h',
-    tfl_dir + 'experimental/c/c_api_types.h',
+    tfl_dir + 'experimental/c/common.h',
   ]
   s.module_map = objc_dir + 'apis/framework.modulemap'
   s.dependency 'TensorFlowLiteC', "#{s.version}"
diff --git a/tensorflow/lite/experimental/objc/sources/TFLInterpreter.mm b/tensorflow/lite/experimental/objc/sources/TFLInterpreter.mm
index e8e69484e21..8ef4c571558 100644
--- a/tensorflow/lite/experimental/objc/sources/TFLInterpreter.mm
+++ b/tensorflow/lite/experimental/objc/sources/TFLInterpreter.mm
@@ -20,7 +20,7 @@
 #import "tensorflow/lite/experimental/objc/apis/TFLInterpreterOptions.h"
 #import "tensorflow/lite/experimental/objc/apis/TFLTensor.h"
 
-#include "tensorflow/lite/experimental/c/c_api.h"
+#include "tensorflow/lite/c/c_api.h"
 
 NS_ASSUME_NONNULL_BEGIN
 
diff --git a/tensorflow/lite/experimental/swift/TensorFlowLite.tulsiproj/Configs/TensorFlowLite.tulsigen b/tensorflow/lite/experimental/swift/TensorFlowLite.tulsiproj/Configs/TensorFlowLite.tulsigen
index 7ad7e33cf09..d919ada871d 100644
--- a/tensorflow/lite/experimental/swift/TensorFlowLite.tulsiproj/Configs/TensorFlowLite.tulsigen
+++ b/tensorflow/lite/experimental/swift/TensorFlowLite.tulsiproj/Configs/TensorFlowLite.tulsigen
@@ -1,6 +1,6 @@
 {
   "sourceFilters" : [
-    "tensorflow/lite/experimental/c",
+    "tensorflow/lite/c",
     "tensorflow/lite/experimental/swift",
     "tensorflow/lite/experimental/swift/Sources",
     "tensorflow/lite/experimental/swift/TestApp/TestApp",

From 05018f92bfb1a9ee839d754c35f6f5e14f4c617b Mon Sep 17 00:00:00 2001
From: Lei Zhang 
Date: Wed, 27 Nov 2019 13:46:20 -0800
Subject: [PATCH 060/279] [spirv] Add folders for spv.IAdd and spv.IMul

Adding zero and multiplying one can be common when generating code
for index calculation.

This CL also sorted canonicalize.mlir to alphabetical order.

PiperOrigin-RevId: 282828055
Change-Id: I27a442882bb41de819100c1e3fea1621afc3440d
---
 .../mlir/Dialect/SPIRV/SPIRVArithmeticOps.td  |  4 +++
 .../mlir/lib/Dialect/SPIRV/SPIRVOps.cpp       | 30 +++++++++++++++++++
 2 files changed, 34 insertions(+)

diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td
index cbcd9303626..00ce72f5b2a 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td
@@ -292,6 +292,8 @@ def SPV_IAddOp : SPV_ArithmeticBinaryOp<"IAdd", SPV_Integer, [Commutative]> {
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -328,6 +330,8 @@ def SPV_IMulOp : SPV_ArithmeticBinaryOp<"IMul", SPV_Integer, [Commutative]> {
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
diff --git a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 6bb052d49d7..ae7643fa915 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -26,6 +26,7 @@
 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Function.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"
@@ -1518,6 +1519,35 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spv.IAdd
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::IAddOp::fold(ArrayRef operands) {
+  assert(operands.size() == 2 && "spv.IAdd expects two operands");
+  // lhs + 0 = lhs
+  if (matchPattern(operand2(), m_Zero()))
+    return operand1();
+
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// spv.IMul
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::IMulOp::fold(ArrayRef operands) {
+  assert(operands.size() == 2 && "spv.IMul expects two operands");
+  // lhs * 0 == 0
+  if (matchPattern(operand2(), m_Zero()))
+    return operand2();
+  // lhs * 1 = lhs
+  if (matchPattern(operand2(), m_One()))
+    return operand1();
+
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // spv.LoadOp
 //===----------------------------------------------------------------------===//

From 0837d8785b52af8ccb38fdb29f6ffe0c6d106715 Mon Sep 17 00:00:00 2001
From: Karim Nosir 
Date: Wed, 27 Nov 2019 13:49:40 -0800
Subject: [PATCH 061/279] - Add relu1 op in tf lite mlir - Add pattern for tf
 lite relu1 - Add type constraints for relu/relu6 that matches the tf lite
 kernel.

PiperOrigin-RevId: 282828589
Change-Id: Ife8a716f3ba5606edf8e2823b0c703f6236c76e4
---
 tensorflow/compiler/mlir/lite/ir/tfl_ops.td   | 23 +++++--
 .../compiler/mlir/lite/tests/legalize-tf.mlir | 68 +++++++++----------
 .../compiler/mlir/lite/tests/optimize.mlir    | 22 ++++++
 .../mlir/lite/transforms/optimize_patterns.td | 16 +++++
 .../lite/testing/generate_examples_lib.py     |  1 +
 tensorflow/lite/testing/op_tests/binary_op.py |  5 +-
 tensorflow/lite/testing/op_tests/relu1.py     |  3 +-
 7 files changed, 97 insertions(+), 41 deletions(-)

diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index e91f5fa1e8e..cfca2745be7 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -1873,9 +1873,9 @@ def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect,
       x -> max(0, x)
   }];
 
-  let arguments = (ins AnyTensor:$x);
+  let arguments = (ins TensorOf<[F32, QUI8, I8]>:$x);
 
-  let results = (outs AnyTensor:$y);
+  let results = (outs TensorOf<[F32, QUI8, I8]>:$y);
 }
 
 def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect,
@@ -1888,9 +1888,24 @@ def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect,
       x -> max(0, min(6, x))
   }];
 
-  let arguments = (ins AnyTensor:$x);
+  let arguments = (ins TensorOf<[F32, QUI8, I8]>:$x);
 
-  let results = (outs AnyTensor:$y);
+  let results = (outs TensorOf<[F32, QUI8, I8]>:$y);
+}
+
+def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [NoSideEffect,
+                                  SameOperandsAndResultShape,
+                                  SameOperandsAndResultsScale]> {
+  let summary = "Relu1 operator";
+
+  let description = [{
+    Element-wise Relu1 operator
+      x -> max(-1, min(1, x))
+  }];
+
+  let arguments = (ins TensorOf<[F32, QUI8, I8]>:$x);
+
+  let results = (outs TensorOf<[F32, QUI8, I8]>:$y);
 }
 
 def TFL_ReshapeOp: TFL_Op<"reshape", [
diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
index c2653f3d6f1..27eff39c397 100644
--- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
@@ -1,22 +1,22 @@
 // RUN: tf-opt %s -tfl-legalize-tf | FileCheck %s --dump-input-on-failure
 
-func @addRelu(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> {
-  %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
-  %1 = "tf.Add"(%arg0, %0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
-  %2 = "tf.Relu"(%1) : (tensor<1xi32>) -> tensor<1xi32>
-  %3 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
-  %4 = "tf.Add"(%3, %2) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
-  %5 = "tf.Relu6"(%4) : (tensor<1xi32>) -> tensor<1xi32>
-  %6 = "tfl.add"(%5, %3) {fused_activation_function = "NONE"} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
-  %7 = "tf.Relu6"(%6) : (tensor<1xi32>) -> tensor<1xi32>
-  return %7: tensor<1xi32>
+func @addRelu(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
+  %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %1 = "tf.Add"(%arg0, %0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %2 = "tf.Relu"(%1) : (tensor<1xf32>) -> tensor<1xf32>
+  %3 = "tf.Relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+  %4 = "tf.Add"(%3, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %5 = "tf.Relu6"(%4) : (tensor<1xf32>) -> tensor<1xf32>
+  %6 = "tfl.add"(%5, %3) {fused_activation_function = "NONE"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %7 = "tf.Relu6"(%6) : (tensor<1xf32>) -> tensor<1xf32>
+  return %7: tensor<1xf32>
 
 // CHECK-LABEL: addRelu
-// CHECK:  tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xi32>
-// CHECK:  %1 = tfl.add %arg0, %0 {fused_activation_function = "RELU"} : tensor<1xi32>
-// CHECK:  %2 = "tfl.relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
-// CHECK:  %3 = tfl.add %2, %1 {fused_activation_function = "RELU6"} : tensor<1xi32>
-// CHECK:  %4 = tfl.add %3, %2 {fused_activation_function = "RELU6"} : tensor<1xi32>
+// CHECK:  tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xf32>
+// CHECK:  %1 = tfl.add %arg0, %0 {fused_activation_function = "RELU"} : tensor<1xf32>
+// CHECK:  %2 = "tfl.relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+// CHECK:  %3 = tfl.add %2, %1 {fused_activation_function = "RELU6"} : tensor<1xf32>
+// CHECK:  %4 = tfl.add %3, %2 {fused_activation_function = "RELU6"} : tensor<1xf32>
 // CHECK:  return
 }
 
@@ -244,32 +244,32 @@ func @zeros_like(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
 // CHECK:  "tfl.zeros_like"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
 }
 
-func @divRelu(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> {
-  %0 = "tf.Div"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
-  %1 = "tf.Div"(%arg0, %0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
-  %2 = "tf.Relu"(%1) : (tensor<1xi32>) -> tensor<1xi32>
-  %3 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
-  %4 = "tf.Div"(%3, %2) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
-  %5 = "tf.Relu6"(%4) : (tensor<1xi32>) -> tensor<1xi32>
-  return %5: tensor<1xi32>
+func @divRelu(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
+  %0 = "tf.Div"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %1 = "tf.Div"(%arg0, %0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %2 = "tf.Relu"(%1) : (tensor<1xf32>) -> tensor<1xf32>
+  %3 = "tf.Relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+  %4 = "tf.Div"(%3, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %5 = "tf.Relu6"(%4) : (tensor<1xf32>) -> tensor<1xf32>
+  return %5: tensor<1xf32>
 
 // CHECK-LABEL: divRelu
-// CHECK:  tfl.div %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xi32>
-// CHECK:  %1 = tfl.div %arg0, %0 {fused_activation_function = "RELU"} : tensor<1xi32>
-// CHECK:  %2 = "tfl.relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
-// CHECK:  %3 = tfl.div %2, %1 {fused_activation_function = "RELU6"} : tensor<1xi32>
+// CHECK:  tfl.div %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xf32>
+// CHECK:  %1 = tfl.div %arg0, %0 {fused_activation_function = "RELU"} : tensor<1xf32>
+// CHECK:  %2 = "tfl.relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+// CHECK:  %3 = tfl.div %2, %1 {fused_activation_function = "RELU6"} : tensor<1xf32>
 // CHECK:  return
 }
 
-func @squaredDifferenceRelu(tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> {
-^bb0(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>):
-  %0 = "tf.SquaredDifference"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
-  %1 = "tf.Relu6"(%0) : (tensor<1xi32>) -> tensor<1xi32>
-  return %1: tensor<1xi32>
+func @squaredDifferenceRelu(tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> {
+^bb0(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>):
+  %0 = "tf.SquaredDifference"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %1 = "tf.Relu6"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+  return %1: tensor<1xf32>
 
 // CHECK-LABEL: squaredDifferenceRelu
-// CHECK:  tfl.squared_difference %arg0, %arg1 : tensor<1xi32>
-// CHECK:  %1 = "tfl.relu6"(%0) : (tensor<1xi32>) -> tensor<1xi32>
+// CHECK:  tfl.squared_difference %arg0, %arg1 : tensor<1xf32>
+// CHECK:  %1 = "tfl.relu6"(%0) : (tensor<1xf32>) -> tensor<1xf32>
 // CHECK:  return
 }
 
diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
index 7d63d8df11b..aaf6eac2d85 100644
--- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
@@ -600,3 +600,25 @@ func @squeezeToReshape(%arg0: tensor<1x1x2xf32>) -> tensor<2xf32> {
   // CHECK: %[[RESULT:.*]] = "tfl.reshape"(%arg0, %[[CONST:.*]]) : (tensor<1x1x2xf32>, tensor<1xi32>) -> tensor<2xf32>
   // CHECK: return %[[RESULT]]
 }
+
+// CHECK-LABEL: Relu1
+func @Relu1(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
+  %cst = constant dense<-1.0> : tensor
+  %cst1 = constant dense<1.0> : tensor
+  %0 = "tfl.maximum"(%arg0, %cst) : (tensor<2x3xf32>, tensor) -> tensor<2x3xf32>
+  %1 = "tfl.minimum"(%0, %cst1) : (tensor<2x3xf32>, tensor) -> tensor<2x3xf32>
+  return %1 : tensor<2x3xf32>
+
+  // CHECK: %[[relu_n1_to_1:[0-9].*]] = "tfl.relu_n1_to_1"
+}
+
+// CHECK-LABEL: Relu1_2
+func @Relu1_2(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
+  %cst = constant dense<-1.0> : tensor
+  %cst1 = constant dense<1.0> : tensor
+  %0 = "tfl.minimum"(%arg0, %cst1) : (tensor<2x3xf32>, tensor) -> tensor<2x3xf32>
+  %1 = "tfl.maximum"(%0, %cst) : (tensor<2x3xf32>, tensor) -> tensor<2x3xf32>
+  return %1 : tensor<2x3xf32>
+
+  // CHECK: %[[relu_n1_to_1:[0-9].*]] = "tfl.relu_n1_to_1"
+}
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
index 78a14f3b409..92276ac0089 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
@@ -273,3 +273,19 @@ def : Pat<(TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims),
           (TFL_ReshapeOp $input,
            (ConstantOp (GetShape $squeeze_op))),
           [(AnyStaticShapeTensor $squeeze_op)]>;
+
+class ValueEquals : Constraint().getNumElements() == 1 &&"
+  "*$0.cast().getValues().begin() == " # val>>;
+
+def : Pat<(TFL_MinimumOp (TFL_MaximumOp $input,
+                          (ConstantOp $NegOne)),
+           (ConstantOp $One)),
+          (TFL_Relu1Op $input),
+          [(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>;
+
+def : Pat<(TFL_MaximumOp (TFL_MinimumOp $input,
+                          (ConstantOp $One)),
+           (ConstantOp $NegOne)),
+          (TFL_Relu1Op $input),
+          [(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>;
diff --git a/tensorflow/lite/testing/generate_examples_lib.py b/tensorflow/lite/testing/generate_examples_lib.py
index f57e25c68b5..1d257e1f3c7 100644
--- a/tensorflow/lite/testing/generate_examples_lib.py
+++ b/tensorflow/lite/testing/generate_examples_lib.py
@@ -237,6 +237,7 @@ class Options(object):
     # test sets.
     # TODO(juhoha): Separate the state from the options.
     self.multi_gen_state = None
+    self.use_experimental_converter = False
 
 
 def _prepare_dir(options):
diff --git a/tensorflow/lite/testing/op_tests/binary_op.py b/tensorflow/lite/testing/op_tests/binary_op.py
index c9900dc288d..88702b0542f 100644
--- a/tensorflow/lite/testing/op_tests/binary_op.py
+++ b/tensorflow/lite/testing/op_tests/binary_op.py
@@ -129,7 +129,10 @@ def make_binary_op_tests(options,
         name="input2",
         shape=parameters["input_shape_2"])
     out = binary_operator(input1, input2)
-    if parameters["activation"]:
+    # TODO(karimnosseir): Update condition after moving to new converter.
+    if parameters["activation"] and (not options.use_experimental_converter or
+                                     (parameters["dtype"] != tf.int32 and
+                                      parameters["dtype"] != tf.int64)):
       out = tf.nn.relu(out)
     return [input1, input2], [out]
 
diff --git a/tensorflow/lite/testing/op_tests/relu1.py b/tensorflow/lite/testing/op_tests/relu1.py
index 21c03c89454..ac92bac1cb2 100644
--- a/tensorflow/lite/testing/op_tests/relu1.py
+++ b/tensorflow/lite/testing/op_tests/relu1.py
@@ -30,8 +30,7 @@ def make_relu1_tests(options):
 
   # Chose a set of parameters
   test_parameters = [{
-      "input_shape": [[], [1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3],
-                      [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
+      "input_shape": [[], [1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3]],
       "fully_quantize": [True, False],
       "input_range": [(-2, 8)]
   }]

From 8f1ddaaa0f4a658aedd446628a7c0181a510c9de Mon Sep 17 00:00:00 2001
From: Derek Murray 
Date: Wed, 27 Nov 2019 14:08:03 -0800
Subject: [PATCH 062/279] Optimize registration and deregistration of child
 CancellationManager objects.

This change replaces the map of std::function callbacks in the parent
with a doubly-linked list of child CancellationManager objects. To
avoid unnecessary allocations, the link pointers are stored within
each child CancellationManager, which can be stack-allocated.

PiperOrigin-RevId: 282831932
Change-Id: I044ae86f7868442ada53010ea871cdf71f53eab7
---
 tensorflow/core/framework/cancellation.cc     | 89 ++++++++++++++++---
 tensorflow/core/framework/cancellation.h      | 21 ++++-
 .../core/framework/cancellation_test.cc       | 32 +++++++
 3 files changed, 130 insertions(+), 12 deletions(-)

diff --git a/tensorflow/core/framework/cancellation.cc b/tensorflow/core/framework/cancellation.cc
index 35c4be63ce4..a91442fcbad 100644
--- a/tensorflow/core/framework/cancellation.cc
+++ b/tensorflow/core/framework/cancellation.cc
@@ -15,6 +15,8 @@ limitations under the License.
 
 #include "tensorflow/core/framework/cancellation.h"
 
+#include 
+
 #include "absl/memory/memory.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/platform/logging.h"
@@ -29,20 +31,13 @@ CancellationManager::CancellationManager()
       next_cancellation_token_(0) {}
 
 CancellationManager::CancellationManager(CancellationManager* parent)
-    : is_cancelling_(false),
-      is_cancelled_(false),
-      next_cancellation_token_(0),
-      parent_(parent),
-      parent_token_(parent->get_cancellation_token()) {
-  bool registered = parent->RegisterCallback(parent_token_,
-                                             [this]() { this->StartCancel(); });
-  if (!registered) {
-    is_cancelled_ = true;
-  }
+    : is_cancelling_(false), next_cancellation_token_(0), parent_(parent) {
+  is_cancelled_ = parent->RegisterChild(this);
 }
 
 void CancellationManager::StartCancel() {
   gtl::FlatMap callbacks_to_run;
+  std::forward_list children_to_cancel;
   Notification* cancelled_notification = nullptr;
   {
     mutex_lock l(mu_);
@@ -52,6 +47,16 @@ void CancellationManager::StartCancel() {
     is_cancelling_ = true;
     if (state_) {
       std::swap(state_->callbacks, callbacks_to_run);
+
+      // Remove all children from the list of children.
+      CancellationManager* child = state_->first_child;
+      while (child != nullptr) {
+        children_to_cancel.push_front(child);
+        child->is_removed_from_parent_ = true;
+        child = child->next_sibling_;
+      }
+      state_->first_child = nullptr;
+
       cancelled_notification = &state_->cancelled_notification;
     }
   }
@@ -63,6 +68,9 @@ void CancellationManager::StartCancel() {
   for (auto key_and_value : callbacks_to_run) {
     key_and_value.second();
   }
+  for (CancellationManager* child : children_to_cancel) {
+    child->StartCancel();
+  }
   {
     mutex_lock l(mu_);
     is_cancelling_ = false;
@@ -113,6 +121,65 @@ bool CancellationManager::DeregisterCallback(CancellationToken token) {
   }
 }
 
+bool CancellationManager::RegisterChild(CancellationManager* child) {
+  mutex_lock l(mu_);
+  if (is_cancelled_.load(std::memory_order_relaxed) || is_cancelling_) {
+    child->is_removed_from_parent_ = true;
+    return true;
+  }
+
+  if (!state_) {
+    state_ = absl::make_unique();
+  }
+
+  // Push `child` onto the front of the list of children.
+  CancellationManager* current_head = state_->first_child;
+  state_->first_child = child;
+  child->prev_sibling_ = nullptr;
+  child->next_sibling_ = current_head;
+  if (current_head) {
+    current_head->prev_sibling_ = child;
+  }
+
+  return false;
+}
+
+void CancellationManager::DeregisterChild(CancellationManager* child) {
+  DCHECK_EQ(child->parent_, this);
+  Notification* cancelled_notification = nullptr;
+  {
+    mutex_lock l(mu_);
+    if (!child->is_removed_from_parent_) {
+      // Remove the child from this manager's list of children.
+      DCHECK(state_);
+
+      if (child->prev_sibling_ == nullptr) {
+        // The child was at the head of the list.
+        DCHECK_EQ(state_->first_child, child);
+        state_->first_child = child->next_sibling_;
+      } else {
+        child->prev_sibling_->next_sibling_ = child->next_sibling_;
+      }
+
+      if (child->next_sibling_ != nullptr) {
+        child->next_sibling_->prev_sibling_ = child->prev_sibling_;
+      }
+
+      child->is_removed_from_parent_ = true;
+    }
+    if (is_cancelling_) {
+      cancelled_notification = &state_->cancelled_notification;
+    }
+  }
+
+  // Wait for an ongoing call to StartCancel() to finish. This wait ensures that
+  // the caller of DeregisterChild does not return immediately and free a child
+  // that may currently be being cancelled by StartCancel().
+  if (cancelled_notification) {
+    cancelled_notification->WaitForNotification();
+  }
+}
+
 bool CancellationManager::TryDeregisterCallback(CancellationToken token) {
   mutex_lock lock(mu_);
   if (is_cancelled_ || is_cancelling_) {
@@ -127,7 +194,7 @@ bool CancellationManager::TryDeregisterCallback(CancellationToken token) {
 
 CancellationManager::~CancellationManager() {
   if (parent_) {
-    parent_->DeregisterCallback(parent_token_);
+    parent_->DeregisterChild(this);
   }
   if (state_) {
     StartCancel();
diff --git a/tensorflow/core/framework/cancellation.h b/tensorflow/core/framework/cancellation.h
index 7a98ee992a9..f9f1e0d19f6 100644
--- a/tensorflow/core/framework/cancellation.h
+++ b/tensorflow/core/framework/cancellation.h
@@ -147,14 +147,33 @@ class CancellationManager {
   struct State {
     Notification cancelled_notification;
     gtl::FlatMap callbacks;
+
+    // If this CancellationManager has any children, this member points to the
+    // head of a doubly-linked list of its children.
+    CancellationManager* first_child;  // Not owned.
   };
 
+  bool RegisterChild(CancellationManager* child);
+  void DeregisterChild(CancellationManager* child);
+
   bool is_cancelling_;
   std::atomic_bool is_cancelled_;
   std::atomic next_cancellation_token_;
 
   CancellationManager* const parent_ = nullptr;  // Not owned.
-  const CancellationToken parent_token_ = kInvalidToken;
+
+  // If this CancellationManager is associated with a parent, this member will
+  // be set to `true` after this is removed from the parent's list of children.
+  bool is_removed_from_parent_ GUARDED_BY(parent_->mu_) = false;
+
+  // If this CancellationManager is associated with a parent, these members form
+  // a doubly-linked list of that parent's children.
+  //
+  // These fields are valid only when `this->is_removed_from_parent_` is false.
+  CancellationManager* prev_sibling_ GUARDED_BY(parent_->mu_) =
+      nullptr;  // Not owned.
+  CancellationManager* next_sibling_ GUARDED_BY(parent_->mu_) =
+      nullptr;  // Not owned.
 
   mutex mu_;
   std::unique_ptr state_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/framework/cancellation_test.cc b/tensorflow/core/framework/cancellation_test.cc
index df73526d594..e4994350ddd 100644
--- a/tensorflow/core/framework/cancellation_test.cc
+++ b/tensorflow/core/framework/cancellation_test.cc
@@ -15,7 +15,11 @@ limitations under the License.
 
 #include "tensorflow/core/framework/cancellation.h"
 
+#include 
+#include 
+#include 
 #include 
+
 #include "tensorflow/core/lib/core/notification.h"
 #include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/platform/test.h"
@@ -199,4 +203,32 @@ TEST(Cancellation, Parent_AlreadyCancelled) {
   EXPECT_TRUE(child.IsCancelled());
 }
 
+TEST(Cancellation, Parent_RandomDestructionOrder) {
+  CancellationManager parent;
+  std::random_device rd;
+  std::mt19937 g(rd());
+
+  // To cover the linked-list codepaths, perform multiple randomized rounds of
+  // registering and deregistering children with `parent`.
+  for (int rounds = 0; rounds < 100; ++rounds) {
+    std::vector> children;
+
+    // 1. Register a random number of children with the parent.
+    std::uniform_int_distribution dist(1, 9);
+    const size_t round_size = dist(rd);
+    for (size_t i = 0; i < round_size; ++i) {
+      children.push_back(absl::make_unique(&parent));
+      EXPECT_FALSE(children.back()->IsCancelled());
+    }
+
+    // 2. Deregister the children in a random order.
+    std::vector destruction_order(round_size);
+    std::iota(destruction_order.begin(), destruction_order.end(), 0);
+    std::shuffle(destruction_order.begin(), destruction_order.end(), g);
+    for (size_t index : destruction_order) {
+      children[index].reset();
+    }
+  }
+}
+
 }  // namespace tensorflow

From fe924581b32b990509c46c50ce51adbb334a895d Mon Sep 17 00:00:00 2001
From: Lei Zhang 
Date: Wed, 27 Nov 2019 14:12:32 -0800
Subject: [PATCH 063/279] [spirv] NFC: Add getZero() and getOne() static method
 to ConstantOp

Getting constant zero or one is very common so it merits a special handy
method on spirv::ConstantOp itself.

PiperOrigin-RevId: 282832572
Change-Id: Ifb6fe54acef73f7ce2af6b995bb06b94a35fd294
---
 .../mlir/Dialect/SPIRV/SPIRVLowering.h        |  2 --
 .../include/mlir/Dialect/SPIRV/SPIRVOps.h     |  2 ++
 .../mlir/Dialect/SPIRV/SPIRVStructureOps.td   |  7 +++++
 .../ConvertStandardToSPIRV.cpp                |  3 +-
 .../mlir/lib/Dialect/SPIRV/SPIRVOps.cpp       | 29 +++++++++++++++++++
 .../Transforms/LowerABIAttributesPass.cpp     |  4 +--
 6 files changed, 41 insertions(+), 6 deletions(-)

diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
index a5b3fc27413..8faa90cb134 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
@@ -56,8 +56,6 @@ public:
 protected:
   /// Type lowering class.
   SPIRVTypeConverter &typeConverter;
-
-private:
 };
 
 #include "mlir/Dialect/SPIRV/SPIRVLowering.h.inc"
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
index 104a4798e7c..353004b6c76 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
@@ -26,6 +26,8 @@
 #include "mlir/IR/Function.h"
 
 namespace mlir {
+class OpBuilder;
+
 namespace spirv {
 
 #define GET_OP_CLASSES
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
index 1ec825aab5c..34b386ebc17 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
@@ -118,6 +118,13 @@ def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> {
   let extraClassDeclaration = [{
     // Returns true if a constant can be built for the given `type`.
     static bool isBuildableWith(Type type);
+
+    // Creates a constant zero/one of the given `type` at the current insertion
+    // point of `builder` and returns it.
+    static spirv::ConstantOp getZero(Type type, Location loc,
+                                     OpBuilder *builder);
+    static spirv::ConstantOp getOne(Type type, Location loc,
+                                    OpBuilder *builder);
   }];
 
   let hasOpcode = 0;
diff --git a/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index 62cabf66a0d..4a3d25fbd38 100644
--- a/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -145,8 +145,7 @@ spirv::AccessChainOp getElementPtr(OpBuilder &builder, Location loc,
 
   // Need to add a '0' at the beginning of the index list for accessing into the
   // struct that wraps the nested array types.
-  Value *zero = builder.create(
-      loc, indexType, builder.getIntegerAttr(indexType, 0));
+  Value *zero = spirv::ConstantOp::getZero(indexType, loc, &builder);
   SmallVector accessIndices;
   accessIndices.reserve(1 + indices.size());
   accessIndices.push_back(zero);
diff --git a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index ae7643fa915..e82420022ea 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1169,6 +1169,35 @@ bool spirv::ConstantOp::isBuildableWith(Type type) {
   return true;
 }
 
+spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
+                                             OpBuilder *builder) {
+  if (auto intType = type.dyn_cast()) {
+    unsigned width = intType.getWidth();
+    Attribute val;
+    if (width == 1)
+      return builder->create(loc, type,
+                                                builder->getBoolAttr(false));
+    return builder->create(
+        loc, type, builder->getIntegerAttr(type, APInt(width, 0)));
+  }
+
+  llvm_unreachable("unimplemented types for ConstantOp::getZero()");
+}
+
+spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
+                                            OpBuilder *builder) {
+  if (auto intType = type.dyn_cast()) {
+    unsigned width = intType.getWidth();
+    if (width == 1)
+      return builder->create(loc, type,
+                                                builder->getBoolAttr(true));
+    return builder->create(
+        loc, type, builder->getIntegerAttr(type, APInt(width, 1)));
+  }
+
+  llvm_unreachable("unimplemented types for ConstantOp::getOne()");
+}
+
 //===----------------------------------------------------------------------===//
 // spv.ControlBarrier
 //===----------------------------------------------------------------------===//
diff --git a/third_party/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/third_party/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index e9d36f66369..d48b31fe491 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -194,8 +194,8 @@ FuncOpLowering::matchAndRewrite(FuncOp funcOp, ArrayRef operands,
     if (isScalarOrVectorType(argType.value())) {
       auto indexType =
           typeConverter.convertType(IndexType::get(funcOp.getContext()));
-      auto zero = rewriter.create(
-          funcOp.getLoc(), indexType, rewriter.getIntegerAttr(indexType, 0));
+      auto zero =
+          spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), &rewriter);
       auto loadPtr = rewriter.create(
           funcOp.getLoc(), replacement, zero.constant());
       replacement = rewriter.create(funcOp.getLoc(), loadPtr,

From 2141a3344bbbaf6a9bb3dc770a410542962f5ff5 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Wed, 27 Nov 2019 14:14:54 -0800
Subject: [PATCH 064/279] Add unit test for reduce_mean with dtype=uint8.

PiperOrigin-RevId: 282832948
Change-Id: Ic05aa1b16b9a858d82243725da363dd0c5194234
---
 tensorflow/python/kernel_tests/reduction_ops_test.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py
index acc66b7c3e6..152e3a3bbf2 100644
--- a/tensorflow/python/kernel_tests/reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/reduction_ops_test.py
@@ -438,6 +438,12 @@ class MeanReductionTest(BaseReductionTest):
       np_arr = self._makeIncremental((2,) * rank, dtypes.int32)
       self._compareAllAxes(np_arr)
 
+  @test_util.run_deprecated_v1
+  def testUint8(self):
+    for rank in range(1, _MAX_RANK + 1):
+      np_arr = self._makeRandom((2,) * rank, dtypes.uint8)
+      self._compareAllAxes(np_arr)
+
   @test_util.run_deprecated_v1
   def testFloat32(self):
     for rank in range(1, _MAX_RANK + 1):

From c9d1b6dc3ada4aa88e97725b11ffb2371e75f355 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Wed, 27 Nov 2019 14:17:37 -0800
Subject: [PATCH 065/279] Set --incompatible_remove_legacy_whole_archive to
 False A roll-forward of cl/281126040 The windows build failure that caused
 the rollback is addressed in cl/282539273

PiperOrigin-RevId: 282833339
Change-Id: I36a4ea4b188880265a80cc52f229e26004b56b17
---
 .bazelrc | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/.bazelrc b/.bazelrc
index 3ad93cdf49f..7219e9e23c2 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -212,8 +212,19 @@ build --announce_rc
 # Other build flags.
 build --define=grpc_no_ares=true
 
-# Prevent regression of https://github.com/bazelbuild/bazel/issues/7362
-build --incompatible_remove_legacy_whole_archive
+# See https://github.com/bazelbuild/bazel/issues/7362 for information on what
+# --incompatible_remove_legacy_whole_archive flag does.
+# This flag is set to true in Bazel 1.0 and newer versions. We tried to migrate
+# Tensorflow to the default, however test coverage wasn't enough to catch the
+# errors.
+# There is ongoing work on Bazel team's side to provide support for transitive
+# shared libraries. As part of migrating to transitive shared libraries, we
+# hope to provide a better mechanism for control over symbol exporting, and
+# then tackle this issue again.
+#
+# TODO: Remove this line once TF doesn't depend on Bazel wrapping all library
+# archives in -whole_archive -no_whole_archive.
+build --noincompatible_remove_legacy_whole_archive
 
 # Modular TF build options
 build:dynamic_kernels --define=dynamic_loaded_kernels=true

From 2e5d9fe645cb3910f88f1d82746ba47b296aa069 Mon Sep 17 00:00:00 2001
From: Karim Nosir 
Date: Wed, 27 Nov 2019 14:30:44 -0800
Subject: [PATCH 066/279] Add clarifying comment for GetShape

PiperOrigin-RevId: 282835297
Change-Id: I16d995948f790d1b71c1f856f9e85bc6ebb945d3
---
 tensorflow/compiler/mlir/lite/transforms/optimize.cc          | 2 ++
 tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
index 43a84b4406a..d8697a8c4e0 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
@@ -140,6 +140,8 @@ ElementsAttr ExpandTo4DForDepthwiseConv(Attribute a) {
   return ExpandTo4DForConvImpl(a, true);
 }
 
+// Returns shape of a ranked tensor.
+// Precondition: output_val's is ranked tensor.
 DenseElementsAttr GetShape(Value *output_val) {
   auto output_type = output_val->getType().cast();
   auto shape_vector = output_type.getShape();
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
index 92276ac0089..905f01d8413 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
@@ -267,6 +267,8 @@ multiclass FuseTileBroadcastIntoFollowingBinary {
 foreach BroadcastingOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp]
   in defm : FuseTileBroadcastIntoFollowingBinary;
 
+// Returns shape of a ranked tensor.
+// if called without a ranked tensor it will fail.
 def GetShape: NativeCodeCall<"GetShape($0)">;
 
 def : Pat<(TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims),

From 7e1a830fc64e2ead11f096d0fc6f92af4e4f1b90 Mon Sep 17 00:00:00 2001
From: Gunhan Gulsoy 
Date: Wed, 27 Nov 2019 14:38:29 -0800
Subject: [PATCH 067/279] Fix a typo in tensorflow/python/BUILD

PiperOrigin-RevId: 282836419
Change-Id: Ia8efb306c6bf736a2a77a23de9b4799123b908fe
---
 tensorflow/python/BUILD | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 3d4f22583a2..613f20e097c 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3687,9 +3687,9 @@ py_library(
     srcs = ["ops/math_ops.py"],
     srcs_version = "PY2AND3",
     deps = [
-        "constant_op",
         ":array_ops",
         ":common_shapes",
+        ":constant_op",
         ":control_flow_ops_gen",
         ":data_flow_ops_gen",
         ":dtypes",

From a5f2c547fdac9ebba4ae7db7320a251e55f196a9 Mon Sep 17 00:00:00 2001
From: Derek Murray 
Date: Wed, 27 Nov 2019 14:49:13 -0800
Subject: [PATCH 068/279] [CancellationManager] Initialize `State::first_child`
 to nullptr.

PiperOrigin-RevId: 282838111
Change-Id: Iedc9bd8a615acab0572f1c4cf0c134e4cd177c20
---
 tensorflow/core/framework/cancellation.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/core/framework/cancellation.h b/tensorflow/core/framework/cancellation.h
index f9f1e0d19f6..3e1727ae54a 100644
--- a/tensorflow/core/framework/cancellation.h
+++ b/tensorflow/core/framework/cancellation.h
@@ -150,7 +150,7 @@ class CancellationManager {
 
     // If this CancellationManager has any children, this member points to the
     // head of a doubly-linked list of its children.
-    CancellationManager* first_child;  // Not owned.
+    CancellationManager* first_child = nullptr;  // Not owned.
   };
 
   bool RegisterChild(CancellationManager* child);

From 4d7f24698f4b5450168279f71cca8b16ef4b0d8c Mon Sep 17 00:00:00 2001
From: Derek Murray 
Date: Wed, 27 Nov 2019 15:03:09 -0800
Subject: [PATCH 069/279] In OpKernelContext, move all optional
 tracking-related members into TrackingState.

The present OpKernelContext has several members (including two mutexes and various vectors) that are only rarely used when various forms of memory tracking are enabled. This change moves those members into a separate struct that is only created when tracking is enabled.

PiperOrigin-RevId: 282840238
Change-Id: I87269aa26b716e71ab686afd7b2cc75f1e4ca51c
---
 tensorflow/core/framework/op_kernel.cc | 120 +++++++++++++------------
 tensorflow/core/framework/op_kernel.h  |  57 +++++++-----
 2 files changed, 98 insertions(+), 79 deletions(-)

diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index 66bb57f736b..959be0781be 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -277,10 +277,11 @@ OpKernelContext::OpKernelContext(Params* params)
           params, static_cast(params->op_kernel->output_types().size())) {}
 
 OpKernelContext::OpKernelContext(Params* params, int num_outputs)
-    : params_(params),
-      outputs_(num_outputs),
-      temp_memory_allocated_(0),
-      persistent_memory_allocated_(0) {
+    : params_(params), outputs_(num_outputs) {
+  if (params_->record_tensor_accesses || params_->track_allocations) {
+    tracking_state_ = absl::make_unique();
+  }
+
   params_->ensure_eigen_gpu_device();
   if (params_->eigen_gpu_device != nullptr) {
     Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes());
@@ -291,9 +292,6 @@ OpKernelContext::OpKernelContext(Params* params, int num_outputs)
       SetStatus(s);
     }
   }
-  if (params_->record_tensor_accesses) {
-    referenced_tensors_.Init();
-  }
 }
 
 OpKernelContext::~OpKernelContext() {
@@ -302,12 +300,12 @@ OpKernelContext::~OpKernelContext() {
       delete value.tensor;
     }
   }
-  if (params_->record_tensor_accesses) referenced_tensors_.Destroy();
-  if (params_->track_allocations && !wrapped_allocators_.empty()) {
+  if (params_->track_allocations &&
+      !tracking_state_->wrapped_allocators.empty()) {
     LOG(WARNING) << "OpKernelContext is tracking allocations but they are not "
                  << "being consumed by the StepStatsCollector.";
-    for (auto& wrapped_alloator : wrapped_allocators_) {
-      wrapped_alloator.second->GetRecordsAndUnRef();
+    for (auto& wrapped_allocator : tracking_state_->wrapped_allocators) {
+      wrapped_allocator.second->GetRecordsAndUnRef();
     }
   }
 }
@@ -321,15 +319,17 @@ Allocator* OpKernelContext::get_allocator(AllocatorAttributes attr) {
     allocator = params_->device->GetAllocator(attr);
   }
   if (TF_PREDICT_FALSE(track_allocations())) {
-    mutex_lock lock(mu_);
-    for (const auto& wrapped : wrapped_allocators_) {
+    DCHECK(tracking_state_);
+    mutex_lock lock(tracking_state_->mu);
+    for (const auto& wrapped : tracking_state_->wrapped_allocators) {
       if (wrapped.first == allocator) {
         return wrapped.second;
       }
     }
     TrackingAllocator* wrapped_allocator =
         new TrackingAllocator(allocator, params_->track_allocations);
-    wrapped_allocators_.push_back(std::make_pair(allocator, wrapped_allocator));
+    tracking_state_->wrapped_allocators.push_back(
+        std::make_pair(allocator, wrapped_allocator));
     return wrapped_allocator;
   } else {
     return allocator;
@@ -341,9 +341,10 @@ void OpKernelContext::SetStatus(const Status& status) {
 }
 
 void OpKernelContext::really_record_tensor_reference(const Tensor& tensor) {
-  mutex_lock l(mu_);
+  DCHECK(tracking_state_);
+  mutex_lock l(tracking_state_->mu);
   // Keep a reference to the underlying memory around.
-  referenced_tensors_->Add(tensor);
+  tracking_state_->referenced_tensors.Add(tensor);
 }
 
 Status OpKernelContext::input(StringPiece name, const Tensor** tensor) {
@@ -804,8 +805,9 @@ Status OpKernelContext::allocate_temp(
       record_temp_memory_allocation(alloc_size, *out_temp);
     }
   } else if (record_memory_consumption_) {
-    mutex_lock l(stats_mu_);
-    temp_memory_allocated_ += out_temp->TotalBytes();
+    DCHECK(tracking_state_);
+    mutex_lock l(tracking_state_->stats_mu);
+    tracking_state_->temp_memory_allocated += out_temp->TotalBytes();
   }
   return s;
 }
@@ -917,20 +919,18 @@ void OpKernelContext::set_output(int index, const Tensor& tensor) {
     record_tensor_reference(tensor);
     outputs_[index] = TensorValue(new Tensor(tensor));
     if (track_allocations() && tensor.TotalBytes() > 0) {
-      mutex_lock l(stats_mu_);
-      if (!temp_tensor_buffer_and_size_) {
-        return;
-      }
+      DCHECK(tracking_state_);
+      mutex_lock l(tracking_state_->stats_mu);
       const auto it = std::find_if(
-          temp_tensor_buffer_and_size_->begin(),
-          temp_tensor_buffer_and_size_->end(),
+          tracking_state_->temp_tensor_buffer_and_size.begin(),
+          tracking_state_->temp_tensor_buffer_and_size.end(),
           [&tensor](const std::pair& e) {
             return e.first ==
                    static_cast(tensor.tensor_data().data());
           });
-      if (it != temp_tensor_buffer_and_size_->end()) {
-        temp_memory_allocated_ -= it->second;
-        temp_tensor_buffer_and_size_->erase(it);
+      if (it != tracking_state_->temp_tensor_buffer_and_size.end()) {
+        tracking_state_->temp_memory_allocated -= it->second;
+        tracking_state_->temp_tensor_buffer_and_size.erase(it);
       }
     }
   }
@@ -1000,57 +1000,67 @@ Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs,
 
 void OpKernelContext::record_temp_memory_allocation(int64 size,
                                                     const Tensor& t) {
-  mutex_lock l(stats_mu_);
-  temp_memory_allocated_ += size;
-  if (!temp_tensor_buffer_and_size_) {
-    temp_tensor_buffer_and_size_.reset(
-        new gtl::InlinedVector, 2>());
+  if (tracking_state_) {
+    mutex_lock l(tracking_state_->stats_mu);
+    tracking_state_->temp_memory_allocated += size;
+    tracking_state_->temp_tensor_buffer_and_size.emplace_back(
+        static_cast(t.tensor_data().data()), size);
   }
-  temp_tensor_buffer_and_size_->emplace_back(
-      static_cast(t.tensor_data().data()), size);
 }
 
 int64 OpKernelContext::temp_memory_allocated() const {
-  mutex_lock l(stats_mu_);
-  return temp_memory_allocated_;
+  if (tracking_state_) {
+    mutex_lock l(tracking_state_->stats_mu);
+    return tracking_state_->temp_memory_allocated;
+  } else {
+    return 0;
+  }
 }
 
 void OpKernelContext::record_persistent_memory_allocation(int64 size,
                                                           int64 alloc_id) {
-  mutex_lock l(stats_mu_);
-  persistent_memory_allocated_ += size;
-  if (alloc_id >= 0) {
-    if (!persistent_alloc_ids_) {
-      persistent_alloc_ids_.reset(new gtl::InlinedVector());
+  if (tracking_state_) {
+    mutex_lock l(tracking_state_->stats_mu);
+    tracking_state_->persistent_memory_allocated += size;
+    if (alloc_id >= 0) {
+      tracking_state_->persistent_alloc_ids.push_back(alloc_id);
     }
-    persistent_alloc_ids_->push_back(alloc_id);
   }
 }
 
 int64 OpKernelContext::persistent_memory_allocated() const {
-  mutex_lock l(stats_mu_);
-  return persistent_memory_allocated_;
+  if (tracking_state_) {
+    mutex_lock l(tracking_state_->stats_mu);
+    return tracking_state_->persistent_memory_allocated;
+  } else {
+    return 0;
+  }
 }
 
 std::vector OpKernelContext::persistent_alloc_ids() const {
-  mutex_lock l(stats_mu_);
-  if (persistent_alloc_ids_) {
-    return std::vector(persistent_alloc_ids_->begin(),
-                              persistent_alloc_ids_->end());
+  if (tracking_state_) {
+    mutex_lock l(tracking_state_->stats_mu);
+    return std::vector(tracking_state_->persistent_alloc_ids.begin(),
+                              tracking_state_->persistent_alloc_ids.end());
   } else {
     return std::vector();
   }
 }
 
 void OpKernelContext::clear_recorded_memory() {
-  mutex_lock l(stats_mu_);
-  temp_memory_allocated_ = 0;
-  persistent_memory_allocated_ = 0;
-  if (temp_tensor_buffer_and_size_) {
-    temp_tensor_buffer_and_size_->clear();
+  if (tracking_state_) {
+    mutex_lock l(tracking_state_->stats_mu);
+    tracking_state_->temp_memory_allocated = 0;
+    tracking_state_->persistent_memory_allocated = 0;
+    tracking_state_->temp_tensor_buffer_and_size.clear();
+    tracking_state_->persistent_alloc_ids.clear();
   }
-  if (persistent_alloc_ids_) {
-    persistent_alloc_ids_->clear();
+}
+
+void OpKernelContext::set_record_memory_consumption(bool v) {
+  record_memory_consumption_ = v;
+  if (v && !tracking_state_) {
+    tracking_state_ = absl::make_unique();
   }
 }
 
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index 149667a9965..7275fb0484d 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -1090,9 +1090,11 @@ class OpKernelContext {
   }
 
   gtl::InlinedVector ConsumeWrappedAllocators() {
-    mutex_lock lock(mu_);
     gtl::InlinedVector retrieved;
-    retrieved.swap(wrapped_allocators_);
+    if (tracking_state_) {
+      mutex_lock lock(tracking_state_->mu);
+      retrieved.swap(tracking_state_->wrapped_allocators);
+    }
     return retrieved;
   }
 
@@ -1233,27 +1235,29 @@ class OpKernelContext {
   // Records temp memory allocation. Tensor object is recorded to identify the
   // case where temp memory is used as output memory.
   void record_temp_memory_allocation(int64 size, const Tensor& t)
-      LOCKS_EXCLUDED(stats_mu_);
+      LOCKS_EXCLUDED(tracking_state_->stats_mu);
 
   // Returns recorded size of temporary memory;
-  int64 temp_memory_allocated() const LOCKS_EXCLUDED(stats_mu_);
+  int64 temp_memory_allocated() const LOCKS_EXCLUDED(tracking_state_->stats_mu);
 
   // Records persistent memory allocation, size can be negative indicating
   // deallocation.
   void record_persistent_memory_allocation(int64 size, int64 alloc_id = -1)
-      LOCKS_EXCLUDED(stats_mu_);
+      LOCKS_EXCLUDED(tracking_state_->stats_mu);
 
   // Returns recorded size and ids of persistent memory.
-  int64 persistent_memory_allocated() const LOCKS_EXCLUDED(stats_mu_);
+  int64 persistent_memory_allocated() const
+      LOCKS_EXCLUDED(tracking_state_->stats_mu);
 
-  std::vector persistent_alloc_ids() const LOCKS_EXCLUDED(stats_mu_);
+  std::vector persistent_alloc_ids() const
+      LOCKS_EXCLUDED(tracking_state_->stats_mu);
 
   // Resets counters for temp and persistent memory and recorded ids.
-  void clear_recorded_memory() LOCKS_EXCLUDED(stats_mu_);
+  void clear_recorded_memory() LOCKS_EXCLUDED(tracking_state_->stats_mu);
 
   bool input_is_ref(int index) const;
 
-  void set_record_memory_consumption(bool v) { record_memory_consumption_ = v; }
+  void set_record_memory_consumption(bool v);
 
   // Used by OpKernel implementations to track actively running deferred ops.
   //
@@ -1312,26 +1316,30 @@ class OpKernelContext {
   Status status_;
   friend class CollectiveExecutor;  // for access to params_
   Params* params_;                  // not owned
-  mutable mutex mu_;  // mutable so const accessors can acquire the lock
-  gtl::InlinedVector wrapped_allocators_ GUARDED_BY(mu_);
   gtl::InlinedVector outputs_;
 
   // Keep track of calls to ScopedAllocator.
   // TODO(ayushd): change to absl::flat_hash_set.
   std::unique_ptr> allocated_scope_ids_;
 
-  // Constructed only if record_tensor_accesses>.
-  ManualConstructor referenced_tensors_ GUARDED_BY(mu_);
-
   // The following data members are only used when allocation tracking is
-  // enabled.
-  mutable mutex stats_mu_;
-  int64 temp_memory_allocated_ GUARDED_BY(stats_mu_);
-  int64 persistent_memory_allocated_ GUARDED_BY(stats_mu_);
-  std::unique_ptr, 2>>
-      temp_tensor_buffer_and_size_ GUARDED_BY(stats_mu_);
-  std::unique_ptr> persistent_alloc_ids_
-      GUARDED_BY(stats_mu_);
+  // enabled, memory consumption is being recorded, or tensor access is being
+  // recorded.
+  struct TrackingState {
+    mutable mutex mu;
+    gtl::InlinedVector wrapped_allocators GUARDED_BY(mu);
+
+    UniqueTensorReferences referenced_tensors GUARDED_BY(mu);
+
+    mutable mutex stats_mu;
+    int64 temp_memory_allocated GUARDED_BY(stats_mu) = 0;
+
+    int64 persistent_memory_allocated GUARDED_BY(stats_mu) = 0;
+    gtl::InlinedVector, 2>
+        temp_tensor_buffer_and_size GUARDED_BY(stats_mu);
+    gtl::InlinedVector persistent_alloc_ids GUARDED_BY(stats_mu);
+  };
+  std::unique_ptr tracking_state_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(OpKernelContext);
 };
@@ -1618,8 +1626,9 @@ inline void OpKernelContext::record_tensor_reference(const Tensor& tensor) {
 inline void OpKernelContext::retrieve_accessed_tensors(
     TensorReferenceVector* out_vector) {
   if (params_->record_tensor_accesses) {
-    mutex_lock l(mu_);
-    referenced_tensors_->FreezeAndReturnReferences(out_vector);
+    DCHECK(tracking_state_);
+    mutex_lock l(tracking_state_->mu);
+    tracking_state_->referenced_tensors.FreezeAndReturnReferences(out_vector);
   }
 }
 

From d059de3e87f470fec0eb7c26d6bc917efe6a600a Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Wed, 27 Nov 2019 15:17:52 -0800
Subject: [PATCH 070/279] Enable maximum/minimum single-op tests.

PiperOrigin-RevId: 282842467
Change-Id: Ia434183c1b218c6c32ecfa3fa150a3aadd12b252
---
 tensorflow/lite/testing/BUILD | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tensorflow/lite/testing/BUILD b/tensorflow/lite/testing/BUILD
index b893ee0524f..25da7cedf01 100644
--- a/tensorflow/lite/testing/BUILD
+++ b/tensorflow/lite/testing/BUILD
@@ -491,8 +491,10 @@ edgetpu_ops = [
     "depthwiseconv",  # high error
     "fully_connected",
     "l2norm",  # high error
+    "maximum",
     "max_pool",
     "mean",
+    "minimum",
     "mul",
     "pad",  # high error
     "relu",

From 03e56f176cc9ed0c7dfa8870fab73e8651b8cac9 Mon Sep 17 00:00:00 2001
From: Gunhan Gulsoy 
Date: Wed, 27 Nov 2019 15:38:48 -0800
Subject: [PATCH 071/279] Remove dependence on core/lib/core:stringpiece

PiperOrigin-RevId: 282845433
Change-Id: I36bd91b278b31d6238d8b0c81328b39c03cb95ba
---
 tensorflow/core/platform/cloud/BUILD                  | 9 ++++++++-
 tensorflow/core/platform/cloud/curl_http_request.h    | 2 +-
 tensorflow/core/platform/cloud/file_block_cache.h     | 2 +-
 tensorflow/core/platform/cloud/http_request.h         | 2 +-
 tensorflow/core/platform/cloud/http_request_fake.h    | 2 +-
 tensorflow/core/platform/cloud/ram_file_block_cache.h | 2 +-
 tensorflow/core/platform/env_test.cc                  | 2 +-
 7 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD
index 27321b3be0e..d578b1a2388 100644
--- a/tensorflow/core/platform/cloud/BUILD
+++ b/tensorflow/core/platform/cloud/BUILD
@@ -35,7 +35,10 @@ cc_library(
     name = "file_block_cache",
     hdrs = ["file_block_cache.h"],
     copts = tf_copts(),
-    deps = ["//tensorflow/core:lib"],
+    deps = [
+        "//tensorflow/core:lib",
+        "//tensorflow/core/platform:stringpiece",
+    ],
 )
 
 cc_library(
@@ -47,6 +50,7 @@ cc_library(
     deps = [
         ":file_block_cache",
         "//tensorflow/core:lib",
+        "//tensorflow/core/platform:stringpiece",
     ],
 )
 
@@ -139,6 +143,7 @@ cc_library(
     deps = [
         "//tensorflow/core:framework_headers_lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/platform:stringpiece",
     ],
 )
 
@@ -151,6 +156,7 @@ cc_library(
         ":http_request",
         "//tensorflow/core:framework_headers_lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/platform:stringpiece",
         "@curl",
     ],
 )
@@ -167,6 +173,7 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:test",
+        "//tensorflow/core/platform:stringpiece",
         "@curl",
     ],
 )
diff --git a/tensorflow/core/platform/cloud/curl_http_request.h b/tensorflow/core/platform/cloud/curl_http_request.h
index ddb1599e871..2e0e368a32b 100644
--- a/tensorflow/core/platform/cloud/curl_http_request.h
+++ b/tensorflow/core/platform/cloud/curl_http_request.h
@@ -22,12 +22,12 @@ limitations under the License.
 
 #include 
 #include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/platform/cloud/http_request.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/stringpiece.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/platform/cloud/file_block_cache.h b/tensorflow/core/platform/cloud/file_block_cache.h
index 3e66a9937a6..d2453016a1c 100644
--- a/tensorflow/core/platform/cloud/file_block_cache.h
+++ b/tensorflow/core/platform/cloud/file_block_cache.h
@@ -23,11 +23,11 @@ limitations under the License.
 #include 
 #include 
 
-#include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/notification.h"
 #include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/stringpiece.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/platform/types.h"
 
diff --git a/tensorflow/core/platform/cloud/http_request.h b/tensorflow/core/platform/cloud/http_request.h
index 91825b5958a..209e51407f9 100644
--- a/tensorflow/core/platform/cloud/http_request.h
+++ b/tensorflow/core/platform/cloud/http_request.h
@@ -21,11 +21,11 @@ limitations under the License.
 #include 
 
 #include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/stringpiece.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/platform/cloud/http_request_fake.h b/tensorflow/core/platform/cloud/http_request_fake.h
index f1bed661715..df0fe9eeb6b 100644
--- a/tensorflow/core/platform/cloud/http_request_fake.h
+++ b/tensorflow/core/platform/cloud/http_request_fake.h
@@ -24,11 +24,11 @@ limitations under the License.
 #include 
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/platform/cloud/curl_http_request.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/stringpiece.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/types.h"
 
diff --git a/tensorflow/core/platform/cloud/ram_file_block_cache.h b/tensorflow/core/platform/cloud/ram_file_block_cache.h
index d418a0fb6b0..97105ff046a 100644
--- a/tensorflow/core/platform/cloud/ram_file_block_cache.h
+++ b/tensorflow/core/platform/cloud/ram_file_block_cache.h
@@ -23,12 +23,12 @@ limitations under the License.
 #include 
 #include 
 
-#include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/platform/cloud/file_block_cache.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/notification.h"
 #include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/stringpiece.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/platform/types.h"
 
diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc
index 1da8aaab743..8298df9a817 100644
--- a/tensorflow/core/platform/env_test.cc
+++ b/tensorflow/core/platform/env_test.cc
@@ -20,13 +20,13 @@ limitations under the License.
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/cord.h"
 #include "tensorflow/core/platform/null_file_system.h"
 #include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/stringpiece.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {

From a10dc733566818da48aec3cd94f2c1ee7c64bfb4 Mon Sep 17 00:00:00 2001
From: Derek Murray 
Date: Wed, 27 Nov 2019 15:51:50 -0800
Subject: [PATCH 072/279] Add PrivateIntraProcessRendezvous.

PrivateIntraProcessRendezvous is a version of the existing IntraProcessRendezvous (now renamed to RefcountedIntraProcessRendezvous with a forwarding alias) that is compatible with stack allocation. It allows users to avoid the overhead of dynamically allocating/destroying an IntraProcessRendezvous and the atomic operations involved in manipulating its reference count.

This change modifies some users of IntraProcessRendezvous to use PrivateIntraProcessRendezvous, where appropriate. In particular, it uses a stack-allocated PrivateIntraProcessRendezvous on the DirectSession::RunInternal() path.

PiperOrigin-RevId: 282847328
Change-Id: I3c54024ea658afb2e2bd27ef35dc421653abc1a8
---
 .../core/common_runtime/direct_session.cc     |  9 +-
 .../core/common_runtime/direct_session.h      |  4 +-
 tensorflow/core/common_runtime/executor.h     | 10 +-
 tensorflow/core/common_runtime/function.cc    | 12 +--
 .../core/common_runtime/function_test.cc      |  5 +-
 .../process_function_library_runtime_test.cc  | 39 ++++----
 .../core/common_runtime/rendezvous_mgr.cc     | 97 +++++++++++++------
 .../core/common_runtime/rendezvous_mgr.h      | 67 +++++++------
 8 files changed, 144 insertions(+), 99 deletions(-)

diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 133a6c31a93..5d42ec208f7 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -521,7 +521,8 @@ Status DirectSession::RunInternal(
                             executor_step_count, &debugger_state));
   }
 
-  run_state.rendez.reset(new IntraProcessRendezvous(device_mgr_.get()));
+  PrivateIntraProcessRendezvous rendezvous(device_mgr_.get());
+
 #ifndef __ANDROID__
   // Set up for collectives if ExecutorsAndKeys declares a key.
   if (executors_and_keys->collective_graph_key !=
@@ -616,7 +617,7 @@ Status DirectSession::RunInternal(
   Executor::Args args;
   args.step_id = step_id;
   args.call_frame = call_frame;
-  args.rendezvous = run_state.rendez.get();
+  args.rendezvous = &rendezvous;
   args.collective_executor =
       (run_state.collective_executor ? run_state.collective_executor->get()
                                      : nullptr);
@@ -695,7 +696,7 @@ Status DirectSession::RunInternal(
     // `barrier` will delete itself after the final executor finishes.
     Notification executors_done;
     ExecutorBarrier* barrier =
-        new ExecutorBarrier(num_executors, run_state.rendez.get(),
+        new ExecutorBarrier(num_executors, &rendezvous,
                             [&run_state, &executors_done](const Status& ret) {
                               {
                                 mutex_lock l(run_state.mu);
@@ -1139,7 +1140,7 @@ Status DirectSession::SendPRunInputs(const NamedTensorList& inputs,
 
 Status DirectSession::RecvPRunOutputs(
     const std::vector& output_names,
-    const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
+    const ExecutorsAndKeys* executors_and_keys, PartialRunState* run_state,
     std::vector* outputs) {
   Status s;
   if (!output_names.empty()) {
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index a272633b4e2..7bbb198ef44 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -191,7 +191,6 @@ class DirectSession : public Session {
   struct RunState {
     mutex mu;
     Status status GUARDED_BY(mu);
-    core::RefCountPtr rendez = nullptr;
     std::unique_ptr collective_executor;
     std::unique_ptr collector;
     TensorStore tensor_store;
@@ -208,6 +207,7 @@ class DirectSession : public Session {
     Notification executors_done;
     std::unordered_map pending_inputs;   // true if fed
     std::unordered_map pending_outputs;  // true if fetched
+    core::RefCountPtr rendez = nullptr;
 
     PartialRunState(const std::vector& pending_input_names,
                     const std::vector& pending_output_names,
@@ -282,7 +282,7 @@ class DirectSession : public Session {
   // tensors are computed.
   ::tensorflow::Status RecvPRunOutputs(
       const std::vector& output_names,
-      const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
+      const ExecutorsAndKeys* executors_and_keys, PartialRunState* run_state,
       std::vector* outputs);
 
   // Check if the specified fetches can be computed from the feeds
diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h
index c147deee694..8e6e9bd9336 100644
--- a/tensorflow/core/common_runtime/executor.h
+++ b/tensorflow/core/common_runtime/executor.h
@@ -166,8 +166,8 @@ class ExecutorBarrier {
   //
   // 'done' is called after the last executor completes, and
   // ExecutorBarrier is deleted.
-  ExecutorBarrier(size_t num, Rendezvous* r, StatusCallback done)
-      : rendez_(r), done_cb_(done), pending_(num) {}
+  ExecutorBarrier(size_t num, RendezvousInterface* r, StatusCallback done)
+      : rendez_(r), done_cb_(std::move(done)), pending_(num) {}
 
   ~ExecutorBarrier() {}
 
@@ -178,7 +178,7 @@ class ExecutorBarrier {
   }
 
  private:
-  Rendezvous* rendez_ = nullptr;
+  RendezvousInterface* rendez_ = nullptr;  // Not owned.
   StatusCallback done_cb_ = nullptr;
 
   mutable mutex mu_;
@@ -186,7 +186,7 @@ class ExecutorBarrier {
   StatusGroup status_group_ GUARDED_BY(mu_);
 
   void WhenDone(const Status& s) {
-    Rendezvous* error_rendez = nullptr;
+    RendezvousInterface* error_rendez = nullptr;
     StatusCallback done = nullptr;
     Status status;
 
@@ -197,7 +197,6 @@ class ExecutorBarrier {
       // Rendezvous object by this thread only.
       if (status_group_.ok() && !s.ok()) {
         error_rendez = rendez_;
-        error_rendez->Ref();
       }
 
       if (!s.ok() && !StatusGroup::IsDerived(s) &&
@@ -219,7 +218,6 @@ class ExecutorBarrier {
     if (error_rendez != nullptr) {
       error_rendez->StartAbort(
           errors::Aborted("Stopping remaining executors."));
-      error_rendez->Unref();
     }
 
     if (done != nullptr) {
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index aa3be38fd29..501002e1f7f 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -1116,11 +1116,11 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
   }
   Options run_opts = opts;
   if (opts.create_rendezvous) {
-    Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_);
+    auto* rendezvous = new PrivateIntraProcessRendezvous(device_mgr_);
     run_opts.rendezvous = rendezvous;
     run_opts.create_rendezvous = false;
-    done = [done = std::move(done), rendezvous](const Status& status) {
-      rendezvous->Unref();
+    done = [done = std::move(done), rendezvous](const Status& status) mutable {
+      delete rendezvous;
       done(status);
     };
   }
@@ -1187,11 +1187,11 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
 
   Options run_opts = opts;
   if (opts.create_rendezvous) {
-    Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_);
+    auto* rendezvous = new PrivateIntraProcessRendezvous(device_mgr_);
     run_opts.rendezvous = rendezvous;
     run_opts.create_rendezvous = false;
-    done = [done = std::move(done), rendezvous](const Status& status) {
-      rendezvous->Unref();
+    done = [done = std::move(done), rendezvous](const Status& status) mutable {
+      delete rendezvous;
       done(status);
     };
   }
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 7c76c469d1e..89e4daa50b3 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -1854,8 +1854,8 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
 
   Tensor y;
   FunctionLibraryRuntime::Options opts;
-  Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_.get());
-  opts.rendezvous = rendezvous;
+  PrivateIntraProcessRendezvous rendezvous(device_mgr_.get());
+  opts.rendezvous = &rendezvous;
   opts.source_device = "/device:CPU:1";
   // Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1.
   TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y}, true));
@@ -1870,7 +1870,6 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
       y,
       test::AsTensor({"/job:localhost/replica:0/task:0/device:CPU:1"},
                               TensorShape({})));
-  rendezvous->Unref();
 }
 
 namespace {
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
index 1a5ed3caa11..55bc408f9c5 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
@@ -110,12 +110,6 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
     }
   }
 
-  ~ProcessFunctionLibraryRuntimeTest() override {
-    if (rendezvous_ != nullptr) {
-      rendezvous_->Unref();
-    }
-  }
-
   void Init(const std::vector& flib,
             const SessionMetadata* session_metadata = nullptr) {
     FunctionDefLibrary proto;
@@ -127,7 +121,8 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
         device_mgr_.get(), Env::Default(), /*config=*/nullptr,
         TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, nullptr, cluster_flr_.get(),
         nullptr, session_metadata));
-    rendezvous_ = new IntraProcessRendezvous(device_mgr_.get());
+    rendezvous_ =
+        absl::make_unique(device_mgr_.get());
   }
 
   Status Instantiate(
@@ -263,7 +258,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
           test::function::FunctionTestSchedClosure(fn);
         };
 
-    opts.rendezvous = rendezvous_;
+    opts.rendezvous = rendezvous_.get();
     opts.runner = &runner;
     Status status;
     Notification done;
@@ -292,7 +287,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
   std::unique_ptr lib_def_;
   std::unique_ptr cluster_flr_;
   std::unique_ptr proc_flr_;
-  IntraProcessRendezvous* rendezvous_ = nullptr;
+  std::unique_ptr rendezvous_ = nullptr;
 };
 
 TEST_F(ProcessFunctionLibraryRuntimeTest, GetFLRNull) {
@@ -344,7 +339,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCall) {
   Init({test::function::XTimesTwo()});
   FunctionLibraryRuntime::Options opts;
   opts.source_device = "/job:a/replica:0/task:0/cpu:0";
-  opts.rendezvous = rendezvous_;
+  opts.rendezvous = rendezvous_.get();
   opts.remote_execution = true;
   FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
   instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
@@ -359,7 +354,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCallFindDevice) {
   Init({test::function::FindDevice()});
   FunctionLibraryRuntime::Options opts;
   opts.source_device = "/job:a/replica:0/task:0/cpu:0";
-  opts.rendezvous = rendezvous_;
+  opts.rendezvous = rendezvous_.get();
   opts.remote_execution = true;
   FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
   instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
@@ -375,7 +370,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceXTimes) {
   auto x = test::AsTensor({1, 2, 3, 4});
   FunctionLibraryRuntime::Options opts;
   opts.source_device = "/job:a/replica:0/task:0/cpu:0";
-  opts.rendezvous = rendezvous_;
+  opts.rendezvous = rendezvous_.get();
   opts.remote_execution = true;
   FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
   instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
@@ -392,7 +387,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceFindDevice) {
   Init({test::function::FindDevice()});
   FunctionLibraryRuntime::Options opts;
   opts.source_device = "/job:a/replica:0/task:0/cpu:0";
-  opts.rendezvous = rendezvous_;
+  opts.rendezvous = rendezvous_.get();
   opts.remote_execution = true;
   FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
   instantiate_opts.target = "/job:a/replica:0/task:0/cpu:1";
@@ -411,7 +406,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) {
   Init({test::function::FindDevice()});
   FunctionLibraryRuntime::Options opts;
   opts.source_device = "/job:a/replica:0/task:0/cpu:0";
-  opts.rendezvous = rendezvous_;
+  opts.rendezvous = rendezvous_.get();
   opts.remote_execution = true;
   Tensor y;
   FunctionLibraryRuntime::InstantiateOptions instantiate_opts_0;
@@ -432,7 +427,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, ClusterFLRSerialTest) {
   Init({test::function::FindDevice()});
   FunctionLibraryRuntime::Options opts;
   opts.source_device = "/job:a/replica:0/task:0/cpu:0";
-  opts.rendezvous = rendezvous_;
+  opts.rendezvous = rendezvous_.get();
   opts.remote_execution = true;
   FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
   instantiate_opts.target = "/job:b/replica:0/task:0/device:CPU:0";
@@ -462,7 +457,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, ClusterFLRParallelTest) {
   Init({test::function::FindDevice()});
   FunctionLibraryRuntime::Options opts;
   opts.source_device = "/job:a/replica:0/task:0/cpu:0";
-  opts.rendezvous = rendezvous_;
+  opts.rendezvous = rendezvous_.get();
   opts.remote_execution = true;
   FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
   instantiate_opts.target = "/job:b/replica:0/task:0/device:CPU:0";
@@ -509,7 +504,7 @@ void TestTwoDeviceMult(
     const string& error = "") {
   fixture->Init({test::function::TwoDeviceMult()});
   FunctionLibraryRuntime::Options opts;
-  opts.rendezvous = fixture->rendezvous_;
+  opts.rendezvous = fixture->rendezvous_.get();
   auto x = test::AsTensor({1, 2, 3});
   Tensor y_cpu;
   Tensor y_gpu;
@@ -542,7 +537,7 @@ void TestTwoDeviceInputOutput(
   fixture->Init({test::function::TwoDeviceInputOutput()});
 
   FunctionLibraryRuntime::Options opts;
-  opts.rendezvous = fixture->rendezvous_;
+  opts.rendezvous = fixture->rendezvous_.get();
   Tensor x1 = test::AsTensor({1, 2});
   if (absl::StrContains(inst_opts.input_devices[0], "GPU")) {
     x1 = fixture->CPUToGPU(x1);
@@ -743,7 +738,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ResourceOutput_GPU) {
 
   // Run the function taking a resource and outputing it
   FunctionLibraryRuntime::Options opts;
-  opts.rendezvous = rendezvous_;
+  opts.rendezvous = rendezvous_.get();
   Tensor x1 = CPUToGPU(test::AsTensor({1, 2}));
   Tensor x2 = GetResourceHandle("my_gpu_var", mgr->default_container(),
                                 "/job:a/replica:0/task:0/device:GPU:0");
@@ -985,7 +980,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataAbsent) {
   Init({SessionMetadataReaderOpFn()}, /*session_metadata=*/nullptr);
   FunctionLibraryRuntime::Options opts;
   opts.source_device = "/job:a/replica:0/task:0/cpu:0";
-  opts.rendezvous = rendezvous_;
+  opts.rendezvous = rendezvous_.get();
   opts.remote_execution = true;
   FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
   instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
@@ -1001,7 +996,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresent) {
   Init({SessionMetadataReaderOpFn()}, &session_metadata);
   FunctionLibraryRuntime::Options opts;
   opts.source_device = "/job:a/replica:0/task:0/cpu:0";
-  opts.rendezvous = rendezvous_;
+  opts.rendezvous = rendezvous_.get();
   opts.remote_execution = true;
   FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
   instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
@@ -1027,7 +1022,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresentAfterCloning) {
   TF_ASSERT_OK(flr->Clone(&cloned_lib_def, &cloned_proc_flr, &cloned_flr));
   FunctionLibraryRuntime::Options opts;
   opts.source_device = "/job:a/replica:0/task:0/cpu:0";
-  opts.rendezvous = rendezvous_;
+  opts.rendezvous = rendezvous_.get();
   opts.remote_execution = true;
   FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
   instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.cc b/tensorflow/core/common_runtime/rendezvous_mgr.cc
index 0d5e79667db..6ed7df2cc1e 100644
--- a/tensorflow/core/common_runtime/rendezvous_mgr.cc
+++ b/tensorflow/core/common_runtime/rendezvous_mgr.cc
@@ -32,23 +32,12 @@ limitations under the License.
 
 namespace tensorflow {
 
-IntraProcessRendezvous::IntraProcessRendezvous(const DeviceMgr* device_mgr)
-    : device_mgr_(device_mgr) {}
-
-IntraProcessRendezvous::~IntraProcessRendezvous() {}
-
-Status IntraProcessRendezvous::Send(const ParsedKey& key,
-                                    const Rendezvous::Args& args,
-                                    const Tensor& val, const bool is_dead) {
-  VLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey();
-  // Buffers "val" and "device_context" in local_.
-  return local_.Send(key, args, val, is_dead);
-}
-
-void IntraProcessRendezvous::SameWorkerRecvDone(
-    const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& send_args,
-    const Rendezvous::Args& recv_args, const Tensor& in, Tensor* out,
-    StatusCallback done) {
+namespace {
+void SameWorkerRecvDone(const DeviceMgr* device_mgr,
+                        const Rendezvous::ParsedKey& parsed,
+                        const Rendezvous::Args& send_args,
+                        const Rendezvous::Args& recv_args, const Tensor& in,
+                        Tensor* out, StatusCallback done) {
   // Do a quick copy (sharing the underlying buffer) if both tensors
   // are on host memory.
   const bool src_host =
@@ -73,13 +62,13 @@ void IntraProcessRendezvous::SameWorkerRecvDone(
   }
 
   Device* src_device;
-  Status s = device_mgr_->LookupDevice(parsed.src_device, &src_device);
+  Status s = device_mgr->LookupDevice(parsed.src_device, &src_device);
   if (!s.ok()) {
     done(s);
     return;
   }
   Device* dst_device;
-  s = device_mgr_->LookupDevice(parsed.dst_device, &dst_device);
+  s = device_mgr->LookupDevice(parsed.dst_device, &dst_device);
   if (!s.ok()) {
     done(s);
     return;
@@ -116,16 +105,18 @@ void IntraProcessRendezvous::SameWorkerRecvDone(
       out, 0 /*dev_to_dev_stream_index*/, std::move(done), sync_dst_compute);
 }
 
-void IntraProcessRendezvous::RecvAsync(const ParsedKey& key,
-                                       const Rendezvous::Args& args,
-                                       DoneCallback done) {
-  VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << key.FullKey();
+void IntraProcessRecvAsyncImpl(const DeviceMgr* device_mgr,
+                               LocalRendezvous* local,
+                               const RendezvousInterface::ParsedKey& parsed,
+                               const Rendezvous::Args& recv_args,
+                               RendezvousInterface::DoneCallback done) {
+  VLOG(1) << "IntraProcessRendezvous Recv " << local << " " << parsed.FullKey();
 
   MEMDEBUG_CACHE_OP("RecvAsync");
   // Recv the tensor from local_.
-  local_.RecvAsync(
-      key, args,
-      [this, key, done = std::move(done)](
+  local->RecvAsync(
+      parsed, recv_args,
+      [device_mgr, parsed, done = std::move(done)](
           const Status& status, const Rendezvous::Args& send_args,
           const Rendezvous::Args& recv_args, const Tensor& in,
           bool is_dead) mutable {
@@ -141,7 +132,7 @@ void IntraProcessRendezvous::RecvAsync(const ParsedKey& key,
         };
 
         if (status.ok() && in.IsInitialized()) {
-          SameWorkerRecvDone(key, send_args, recv_args, in, out,
+          SameWorkerRecvDone(device_mgr, parsed, send_args, recv_args, in, out,
                              std::move(final_callback));
         } else {
           final_callback(status);
@@ -149,8 +140,56 @@ void IntraProcessRendezvous::RecvAsync(const ParsedKey& key,
       });
 }
 
-void IntraProcessRendezvous::StartAbort(const Status& s) {
-  CHECK(!s.ok());
+}  // namespace
+
+RefCountedIntraProcessRendezvous::RefCountedIntraProcessRendezvous(
+    const DeviceMgr* device_mgr)
+    : device_mgr_(device_mgr) {}
+
+RefCountedIntraProcessRendezvous::~RefCountedIntraProcessRendezvous() {}
+
+Status RefCountedIntraProcessRendezvous::Send(const ParsedKey& key,
+                                              const Rendezvous::Args& args,
+                                              const Tensor& val,
+                                              const bool is_dead) {
+  VLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey();
+  return local_.Send(key, args, val, is_dead);
+}
+
+void RefCountedIntraProcessRendezvous::RecvAsync(const ParsedKey& key,
+                                                 const Rendezvous::Args& args,
+                                                 DoneCallback done) {
+  VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << key.FullKey();
+  IntraProcessRecvAsyncImpl(device_mgr_, &local_, key, args, std::move(done));
+}
+
+void RefCountedIntraProcessRendezvous::StartAbort(const Status& s) {
+  local_.StartAbort(s);
+}
+
+PrivateIntraProcessRendezvous::PrivateIntraProcessRendezvous(
+    const DeviceMgr* device_mgr)
+    : device_mgr_(device_mgr) {}
+
+PrivateIntraProcessRendezvous::~PrivateIntraProcessRendezvous() {}
+
+Status PrivateIntraProcessRendezvous::Send(const ParsedKey& key,
+                                           const Rendezvous::Args& args,
+                                           const Tensor& val,
+                                           const bool is_dead) {
+  DVLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey();
+  return local_.Send(key, args, val, is_dead);
+}
+
+void PrivateIntraProcessRendezvous::RecvAsync(const ParsedKey& key,
+                                              const Rendezvous::Args& args,
+                                              DoneCallback done) {
+  DVLOG(1) << "StackAllocatedIntraProcessRendezvous Recv " << this << " "
+           << key.FullKey();
+  IntraProcessRecvAsyncImpl(device_mgr_, &local_, key, args, std::move(done));
+}
+
+void PrivateIntraProcessRendezvous::StartAbort(const Status& s) {
   local_.StartAbort(s);
 }
 
diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.h b/tensorflow/core/common_runtime/rendezvous_mgr.h
index a9d3de122f0..eea5fbe388c 100644
--- a/tensorflow/core/common_runtime/rendezvous_mgr.h
+++ b/tensorflow/core/common_runtime/rendezvous_mgr.h
@@ -30,48 +30,61 @@ limitations under the License.
 
 namespace tensorflow {
 
-// IntraProcessRendezvous is a Rendezvous which expects all producers
-// and consumers to be devices immediately accessible within the
-// process. That is, it will never be necessary to perform an RPC to
+// The IntraProcessRendezvous classes are implementations of a Rendezvous that
+// expects all producers and consumers to be devices immediately accessible
+// within the process. That is, it will never be necessary to perform an RPC to
 // communicate with either.
 //
-// Buffering of Tensor values is delegated to a `LocalRendezvous`. This class
-// just adds functionality to coordinate multiple process-local devices.
-class IntraProcessRendezvous : public Rendezvous {
- public:
-  explicit IntraProcessRendezvous(const DeviceMgr* device_mgr);
+// Buffering of Tensor values is delegated to a `LocalRendezvous`. An
+// IntraProcessRendezvous. just adds functionality to coordinate multiple
+// process-local devices.
 
-  // Forwards to local_, where the Tensor "val" will be buffered and
-  // any waiting callback stored.
+// Reference-counted implementation that may be shared between multiple threads.
+class RefCountedIntraProcessRendezvous : public Rendezvous {
+ public:
+  explicit RefCountedIntraProcessRendezvous(const DeviceMgr* device_mgr);
+
+  // Implementation of RendezvousInterface methods.
   Status Send(const ParsedKey& key, const Rendezvous::Args& args,
               const Tensor& val, const bool is_dead) override;
-
-  // This method is called only by the RecvOp.  It tests to see
-  // whether the value will be produced by a local or remote device
-  // and handles accordingly.  In the local case it forwards to
-  // local_, in the remote case it initiates an RPC request.
   void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args,
                  DoneCallback done) override;
-
   void StartAbort(const Status& status) override;
 
  private:
   const DeviceMgr* device_mgr_;
   LocalRendezvous local_;
 
-  ~IntraProcessRendezvous() override;
+  ~RefCountedIntraProcessRendezvous() override;
 
-  // Callback handling the case when a rendezvous has been
-  // accomplished in local_ and the consumer is local to this process.
-  // Tensor "in" will be copied into "out". The key "parsed" encodes
-  // the src and dst devices.
-  typedef std::function StatusCallback;
-  void SameWorkerRecvDone(const Rendezvous::ParsedKey& parsed,
-                          const Rendezvous::Args& send_args,
-                          const Rendezvous::Args& recv_args, const Tensor& in,
-                          Tensor* out, StatusCallback done);
+  TF_DISALLOW_COPY_AND_ASSIGN(RefCountedIntraProcessRendezvous);
+};
 
-  TF_DISALLOW_COPY_AND_ASSIGN(IntraProcessRendezvous);
+// RefCountedIntraProcessRendezvous is aliased to IntraProcessRendezvous for
+// backwards compatibility with existing users.
+using IntraProcessRendezvous = RefCountedIntraProcessRendezvous;
+
+// Non-reference-counted implementation that may be stack-allocated for
+// performance.
+//
+// Prefer to use PrivateIntraProcessRendezvous in new code.
+class PrivateIntraProcessRendezvous : public RendezvousInterface {
+ public:
+  explicit PrivateIntraProcessRendezvous(const DeviceMgr* device_mgr);
+  ~PrivateIntraProcessRendezvous() override;
+
+  // Implementation of RendezvousInterface methods.
+  Status Send(const ParsedKey& key, const Rendezvous::Args& args,
+              const Tensor& val, const bool is_dead) override;
+  void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args,
+                 DoneCallback done) override;
+  void StartAbort(const Status& status) override;
+
+ private:
+  const DeviceMgr* device_mgr_;
+  LocalRendezvous local_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(PrivateIntraProcessRendezvous);
 };
 
 }  // end namespace tensorflow

From 74faaeb08fc5e06acadc57aa5cb3e8fb2a809b08 Mon Sep 17 00:00:00 2001
From: Eugene Zhulenev 
Date: Wed, 27 Nov 2019 15:54:13 -0800
Subject: [PATCH 073/279] Add benchmarks for Conv2D input gradient with strides
 not equal to one

PiperOrigin-RevId: 282847663
Change-Id: Iaa201f4f8e74a62380e68167dde75a158689baf8
---
 .../conv_grad_input_ops_benchmark_test.cc     | 74 +++++++++++--------
 1 file changed, 42 insertions(+), 32 deletions(-)

diff --git a/tensorflow/core/kernels/conv_grad_input_ops_benchmark_test.cc b/tensorflow/core/kernels/conv_grad_input_ops_benchmark_test.cc
index 70a08b2496c..713c935dcf7 100644
--- a/tensorflow/core/kernels/conv_grad_input_ops_benchmark_test.cc
+++ b/tensorflow/core/kernels/conv_grad_input_ops_benchmark_test.cc
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include 
 #include 
 
 #include "tensorflow/cc/ops/standard_ops.h"
@@ -40,7 +39,7 @@ template 
 static Graph* Conv2DBackpropInput(int batch, int height, int width,
                                   int in_depth, int filter_h, int filter_w,
                                   int out_depth, int stride_h, int stride_w,
-                                  TensorFormat data_format) {
+                                  Padding padding, TensorFormat data_format) {
   auto* graph = new Graph(OpRegistry::Global());
 
   Tensor input_t = data_format == FORMAT_NHWC
@@ -53,7 +52,7 @@ static Graph* Conv2DBackpropInput(int batch, int height, int width,
   Conv2DParameters params;
   params.dilations = {1, 1, 1, 1};
   params.strides = {1, stride_h, stride_w, 1};
-  params.padding = Padding::SAME;
+  params.padding = padding;
   params.data_format = data_format;
 
   Conv2DDimensions conv2d_dims;
@@ -85,7 +84,9 @@ static Graph* Conv2DBackpropInput(int batch, int height, int width,
           .Input(backprop)
           .Attr("T", DataTypeToEnum::value)
           .Attr("strides", {1, stride_h, stride_w, 1})
-          .Attr("padding", "SAME")
+          .Attr("padding", padding == Padding::SAME
+                               ? "SAME"
+                               : padding == Padding::VALID ? "VALID" : "N/A")
           .Attr("data_format", ToString(data_format))
           .Finalize(graph, &conv2d));
 
@@ -94,7 +95,7 @@ static Graph* Conv2DBackpropInput(int batch, int height, int width,
 
 // Macro arguments names: --------------------------------------------------- //
 //      T: data type
-// FORMAT: data format (NHWC or NCHW)
+// FMT: data format (NHWC or NCHW)
 //      N: batch size
 //      H: height
 //      W: width
@@ -107,41 +108,50 @@ static Graph* Conv2DBackpropInput(int batch, int height, int width,
 
 #define BM_CONCAT(a, b) a##_##b
 
-#define BM_NAME(name, type, T, FORMAT, N, H, W, C, FH, FW, FC, SH, SW) \
-  BM_CONCAT(name##_##T##_##FORMAT##_##type##_in##N##x##H##x##W##x##C,  \
-            f##FH##x##FW##x##FC##_##s##SH##x##SW)
+#define BM_NAME(name, type, T, FMT, N, H, W, C, FH, FW, FC, SH, SW, PADDING) \
+  BM_CONCAT(name##_##T##_##FMT##_##type##_in##N##x##H##x##W##x##C,           \
+            f##FH##x##FW##x##FC##_##s##SH##x##SW##_##PADDING)
 
-#define BM_Conv2DBwdInputFmt(T, FORMAT, N, H, W, C, FW, FH, FC, SH, SW, type)  \
-  static void BM_NAME(BM_Conv2DBackpropInput, type, T, FORMAT, N, H, W, C, FH, \
-                      FW, FC, SH, SW)(int iters) {                             \
-    testing::ItemsProcessed(static_cast(iters) * (N) * (H) * (W) *      \
-                            (C));                                              \
-    test::Benchmark(#type, Conv2DBackpropInput(N, H, W, C, FH, FW, FC, SH,  \
-                                                  SW, FORMAT_##FORMAT))        \
-        .Run(iters);                                                           \
-  }                                                                            \
-  BENCHMARK(BM_NAME(BM_Conv2DBackpropInput, type, T, FORMAT, N, H, W, C, FH,   \
-                    FW, FC, SH, SW));
+#define BM_Conv2DBwdInput(T, FMT, N, H, W, C, FW, FH, FC, SH, SW, PADDING,    \
+                          type)                                               \
+  static void BM_NAME(BM_Conv2DBackpropInput, type, T, FMT, N, H, W, C, FH,   \
+                      FW, FC, SH, SW, PADDING)(int iters) {                   \
+    testing::ItemsProcessed(static_cast(iters) * (N) * (H) * (W) *     \
+                            (C));                                             \
+    test::Benchmark(#type, Conv2DBackpropInput(N, H, W, C, FH, FW, FC, SH, \
+                                                  SW, PADDING, FORMAT_##FMT)) \
+        .Run(iters);                                                          \
+  }                                                                           \
+  BENCHMARK(BM_NAME(BM_Conv2DBackpropInput, type, T, FMT, N, H, W, C, FH, FW, \
+                    FC, SH, SW, PADDING));
 
 using fp32 = float;
 using fp16 = Eigen::half;
 
 // ResNet50-ish convolutions.
-#define BENCHMARK_DTYPE(FORMAT, BATCH, T, D)                                \
-  BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 56, 56, 64, 1, 1, 64, 1, 1, D);    \
-  BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 56, 56, 64, 1, 1, 256, 1, 1, D);   \
-  BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 56, 56, 256, 1, 1, 64, 1, 1, D);   \
-  BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 56, 56, 64, 3, 3, 64, 1, 1, D);    \
+#define BENCHMARK_DTYPE(FMT, BATCH, T, D)                                   \
+  BM_Conv2DBwdInput(T, FMT, BATCH, 112, 112, 64, 2, 2, 64, 2, 2, SAME, D);  \
+  BM_Conv2DBwdInput(T, FMT, BATCH, 56, 56, 128, 2, 2, 128, 2, 2, SAME, D);  \
+  BM_Conv2DBwdInput(T, FMT, BATCH, 28, 28, 256, 2, 2, 256, 2, 2, SAME, D);  \
                                                                             \
-  BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 28, 28, 128, 1, 1, 128, 1, 1, D);  \
-  BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 28, 28, 128, 1, 1, 512, 1, 1, D);  \
-  BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 28, 28, 512, 1, 1, 128, 1, 1, D);  \
-  BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 28, 28, 512, 3, 3, 128, 1, 1, D);  \
+  BM_Conv2DBwdInput(T, FMT, BATCH, 112, 112, 64, 2, 2, 64, 2, 2, VALID, D); \
+  BM_Conv2DBwdInput(T, FMT, BATCH, 56, 56, 128, 2, 2, 128, 2, 2, VALID, D); \
+  BM_Conv2DBwdInput(T, FMT, BATCH, 28, 28, 256, 2, 2, 256, 2, 2, VALID, D); \
                                                                             \
-  BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 14, 14, 256, 1, 1, 256, 1, 1, D);  \
-  BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 14, 14, 256, 1, 1, 1024, 1, 1, D); \
-  BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 14, 14, 1024, 1, 1, 256, 1, 1, D); \
-  BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 14, 14, 256, 3, 3, 256, 1, 1, D);
+  BM_Conv2DBwdInput(T, FMT, BATCH, 56, 56, 64, 1, 1, 64, 1, 1, SAME, D);    \
+  BM_Conv2DBwdInput(T, FMT, BATCH, 56, 56, 64, 1, 1, 256, 1, 1, SAME, D);   \
+  BM_Conv2DBwdInput(T, FMT, BATCH, 56, 56, 256, 1, 1, 64, 1, 1, SAME, D);   \
+  BM_Conv2DBwdInput(T, FMT, BATCH, 56, 56, 64, 3, 3, 64, 1, 1, SAME, D);    \
+                                                                            \
+  BM_Conv2DBwdInput(T, FMT, BATCH, 28, 28, 128, 1, 1, 128, 1, 1, SAME, D);  \
+  BM_Conv2DBwdInput(T, FMT, BATCH, 28, 28, 128, 1, 1, 512, 1, 1, SAME, D);  \
+  BM_Conv2DBwdInput(T, FMT, BATCH, 28, 28, 512, 1, 1, 128, 1, 1, SAME, D);  \
+  BM_Conv2DBwdInput(T, FMT, BATCH, 28, 28, 512, 3, 3, 128, 1, 1, SAME, D);  \
+                                                                            \
+  BM_Conv2DBwdInput(T, FMT, BATCH, 14, 14, 256, 1, 1, 256, 1, 1, SAME, D);  \
+  BM_Conv2DBwdInput(T, FMT, BATCH, 14, 14, 256, 1, 1, 1024, 1, 1, SAME, D); \
+  BM_Conv2DBwdInput(T, FMT, BATCH, 14, 14, 1024, 1, 1, 256, 1, 1, SAME, D); \
+  BM_Conv2DBwdInput(T, FMT, BATCH, 14, 14, 256, 3, 3, 256, 1, 1, SAME, D);
 
 BENCHMARK_DTYPE(NHWC, 8, fp32, cpu);
 BENCHMARK_DTYPE(NHWC, 16, fp32, cpu);

From 23fde233bf3210759b5a4453bc39101df9c86d0c Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Wed, 27 Nov 2019 15:54:18 -0800
Subject: [PATCH 074/279] Do mean reductions for integer types in 64 bit to
 mitigate overflow in the sum and/or denominator.

PiperOrigin-RevId: 282847676
Change-Id: I267823932b2c3e1f9916ea0edfcce3efb5d4430a
---
 tensorflow/core/kernels/reduction_ops.h       | 28 +++++++++++++++++++
 .../python/kernel_tests/reduction_ops_test.py | 21 ++++++++++++++
 2 files changed, 49 insertions(+)

diff --git a/tensorflow/core/kernels/reduction_ops.h b/tensorflow/core/kernels/reduction_ops.h
index 3c62dcfc081..46d8051fff1 100644
--- a/tensorflow/core/kernels/reduction_ops.h
+++ b/tensorflow/core/kernels/reduction_ops.h
@@ -72,6 +72,34 @@ struct ReduceEigenImpl                                           \
+  struct ReduceEigenImpl> {                  \
+    void operator()(const Device& d, OUT_T out, IN_T in,                      \
+                    const ReductionAxes& reduction_axes,                      \
+                    const functor::MeanReducer& reducer) {        \
+      static_assert(std::is_same::value,  \
+                    "");                                                      \
+      Eigen::internal::SumReducer sum_reducer;              \
+      out.device(d) = (in.template cast().reduce(           \
+                           reduction_axes, sum_reducer) /                     \
+                       static_cast(in.size() / out.size())) \
+                          .template cast();                       \
+    }                                                                         \
+  }
+
+CASTING_SPECIALIZATION(uint8, uint64);
+CASTING_SPECIALIZATION(uint16, uint64);
+CASTING_SPECIALIZATION(uint32, uint64);
+CASTING_SPECIALIZATION(int8, int64);
+CASTING_SPECIALIZATION(int16, int64);
+CASTING_SPECIALIZATION(int32, int64);
+#undef CASTING_SPECIALIZATION
+
 // TODO(rmlarsen): Refactor this such that taking the sqrt can be optional
 // controlled by an attribute.
 template 
Date: Wed, 27 Nov 2019 15:55:07 -0800
Subject: [PATCH 075/279] NFC: A few cleanups for SPIRVLowering

Updated comments and used static instead of anonymous namspace
to hide functions to be consistent with the existing codebase.

PiperOrigin-RevId: 282847784
Change-Id: I250d6692c0d6b21d467b2e6fe6540265236b3e10
---
 .../mlir/Dialect/SPIRV/SPIRVLowering.h        |  19 ++--
 .../mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp  | 104 +++++++++---------
 2 files changed, 63 insertions(+), 60 deletions(-)

diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
index 8faa90cb134..306f2b9f309 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
@@ -30,21 +30,23 @@
 
 namespace mlir {
 
-/// Converts a function type according to the requirements of a SPIR-V entry
-/// function. The arguments need to be converted to spv.GlobalVariables of
-/// spv.ptr types so that they could be bound by the runtime.
+/// Type conversion from stdandard types to SPIR-V types for shader interface.
+///
+/// For composite types, this converter additionally performs type wrapping to
+/// satisfy shader interface requirements: shader interface types must be
+/// pointers to structs.
 class SPIRVTypeConverter final : public TypeConverter {
 public:
   using TypeConverter::TypeConverter;
 
-  /// Converts types to SPIR-V types using the basic type converter.
-  Type convertType(Type t) override;
+  /// Converts the given standard `type` to SPIR-V correspondance.
+  Type convertType(Type type) override;
 
-  /// Gets the index type equivalent in SPIR-V.
-  Type getIndexType(MLIRContext *context);
+  /// Gets the SPIR-V correspondance for the standard index type.
+  static Type getIndexType(MLIRContext *context);
 };
 
-/// Base class to define a conversion pattern to translate Ops into SPIR-V.
+/// Base class to define a conversion pattern to lower `SourceOp` into SPIR-V.
 template 
 class SPIRVOpLowering : public OpConversionPattern {
 public:
@@ -54,7 +56,6 @@ public:
         typeConverter(typeConverter) {}
 
 protected:
-  /// Type lowering class.
   SPIRVTypeConverter &typeConverter;
 };
 
diff --git a/third_party/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/third_party/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
index 3c571add56a..baa9ed305aa 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -68,8 +68,7 @@ mlir::spirv::getEntryPointABIAttr(ArrayRef localSize,
 // Type Conversion
 //===----------------------------------------------------------------------===//
 
-namespace {
-Type convertIndexType(MLIRContext *context) {
+Type SPIRVTypeConverter::getIndexType(MLIRContext *context) {
   // Convert to 32-bit integers for now. Might need a way to control this in
   // future.
   // TODO(ravishankarm): It is porbably better to make it 64-bit integers. To
@@ -82,7 +81,7 @@ Type convertIndexType(MLIRContext *context) {
 
 // TODO(ravishankarm): This is a utility function that should probably be
 // exposed by the SPIR-V dialect. Keeping it local till the use case arises.
-Optional getTypeNumBytes(Type t) {
+static Optional getTypeNumBytes(Type t) {
   if (auto integerType = t.dyn_cast()) {
     return integerType.getWidth() / 8;
   } else if (auto floatType = t.dyn_cast()) {
@@ -92,17 +91,17 @@ Optional getTypeNumBytes(Type t) {
   return llvm::None;
 }
 
-Type typeConversionImpl(Type t) {
-  // Check if the type is SPIR-V supported. If so return the type.
-  if (spirv::SPIRVDialect::isValidType(t)) {
-    return t;
+static Type convertStdType(Type type) {
+  // If the type is already valid in SPIR-V, directly return.
+  if (spirv::SPIRVDialect::isValidType(type)) {
+    return type;
   }
 
-  if (auto indexType = t.dyn_cast()) {
-    return convertIndexType(t.getContext());
+  if (auto indexType = type.dyn_cast()) {
+    return SPIRVTypeConverter::getIndexType(type.getContext());
   }
 
-  if (auto memRefType = t.dyn_cast()) {
+  if (auto memRefType = type.dyn_cast()) {
     // TODO(ravishankarm): For now only support default memory space. The memory
     // space description is not set is stone within MLIR, i.e. it depends on the
     // context it is being used. To map this to SPIR-V storage classes, we
@@ -111,60 +110,65 @@ Type typeConversionImpl(Type t) {
     if (memRefType.getMemorySpace()) {
       return Type();
     }
-    auto elementType = typeConversionImpl(memRefType.getElementType());
+
+    auto elementType = convertStdType(memRefType.getElementType());
     if (!elementType) {
       return Type();
     }
+
     auto elementSize = getTypeNumBytes(elementType);
     if (!elementSize) {
       return Type();
     }
-    // TODO(ravishankarm) : Handle dynamic shapes.
-    if (memRefType.hasStaticShape()) {
-      // Get the strides and offset
-      int64_t offset;
-      SmallVector strides;
-      if (failed(getStridesAndOffset(memRefType, strides, offset)) ||
-          offset == MemRefType::getDynamicStrideOrOffset() ||
-          llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
-        // TODO(ravishankarm) : Handle dynamic strides and offsets.
-        return Type();
-      }
-      // Convert to a multi-dimensional spv.array if size is known.
-      auto shape = memRefType.getShape();
-      assert(shape.size() == strides.size());
-      for (int i = shape.size(); i > 0; --i) {
-        elementType = spirv::ArrayType::get(
-            elementType, shape[i - 1], strides[i - 1] * elementSize.getValue());
-      }
-      // For the offset, need to wrap the array in a struct.
-      auto structType =
-          spirv::StructType::get(elementType, offset * elementSize.getValue());
-      // For now initialize the storage class to StorageBuffer. This will be
-      // updated later based on whats passed in w.r.t to the ABI attributes.
-      return spirv::PointerType::get(structType,
-                                     spirv::StorageClass::StorageBuffer);
+
+    if (!memRefType.hasStaticShape()) {
+      // TODO(ravishankarm) : Handle dynamic shapes.
+      return Type();
     }
+
+    // Get the strides and offset.
+    int64_t offset;
+    SmallVector strides;
+    if (failed(getStridesAndOffset(memRefType, strides, offset)) ||
+        offset == MemRefType::getDynamicStrideOrOffset() ||
+        llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
+      // TODO(ravishankarm) : Handle dynamic strides and offsets.
+      return Type();
+    }
+
+    // Convert to a multi-dimensional spv.array if size is known.
+    auto shape = memRefType.getShape();
+    assert(shape.size() == strides.size());
+    Type arrayType = elementType;
+    // TODO(antiagainst): Introduce layout as part of the shader ABI to have
+    // better separate of concerns.
+    for (int i = shape.size(); i > 0; --i) {
+      arrayType = spirv::ArrayType::get(
+          arrayType, shape[i - 1], strides[i - 1] * elementSize.getValue());
+    }
+
+    // For the offset, need to wrap the array in a struct.
+    auto structType =
+        spirv::StructType::get(arrayType, offset * elementSize.getValue());
+    // For now initialize the storage class to StorageBuffer. This will be
+    // updated later based on whats passed in w.r.t to the ABI attributes.
+    return spirv::PointerType::get(structType,
+                                   spirv::StorageClass::StorageBuffer);
   }
+
   return Type();
 }
-} // namespace
 
-Type SPIRVTypeConverter::convertType(Type t) { return typeConversionImpl(t); }
-
-Type SPIRVTypeConverter::getIndexType(MLIRContext *context) {
-  return convertType(IndexType::get(context));
-}
+Type SPIRVTypeConverter::convertType(Type type) { return convertStdType(type); }
 
 //===----------------------------------------------------------------------===//
 // Builtin Variables
 //===----------------------------------------------------------------------===//
 
-namespace {
 /// Look through all global variables in `moduleOp` and check if there is a
 /// spv.globalVariable that has the same `builtin` attribute.
-spirv::GlobalVariableOp getBuiltinVariable(spirv::ModuleOp &moduleOp,
-                                           spirv::BuiltIn builtin) {
+static spirv::GlobalVariableOp getBuiltinVariable(spirv::ModuleOp &moduleOp,
+                                                  spirv::BuiltIn builtin) {
   for (auto varOp : moduleOp.getBlock().getOps()) {
     if (auto builtinAttr = varOp.getAttrOfType(convertToSnakeCase(
             stringifyDecoration(spirv::Decoration::BuiltIn)))) {
@@ -178,15 +182,14 @@ spirv::GlobalVariableOp getBuiltinVariable(spirv::ModuleOp &moduleOp,
 }
 
 /// Gets name of global variable for a buitlin.
-std::string getBuiltinVarName(spirv::BuiltIn builtin) {
+static std::string getBuiltinVarName(spirv::BuiltIn builtin) {
   return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__";
 }
 
 /// Gets or inserts a global variable for a builtin within a module.
-spirv::GlobalVariableOp getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp,
-                                                   Location loc,
-                                                   spirv::BuiltIn builtin,
-                                                   OpBuilder &builder) {
+static spirv::GlobalVariableOp
+getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp, Location loc,
+                           spirv::BuiltIn builtin, OpBuilder &builder) {
   if (auto varOp = getBuiltinVariable(moduleOp, builtin)) {
     return varOp;
   }
@@ -217,7 +220,6 @@ spirv::GlobalVariableOp getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp,
   builder.restoreInsertionPoint(ip);
   return newVarOp;
 }
-} // namespace
 
 /// Gets the global variable associated with a builtin and add
 /// it if it doesnt exist.

From 9e2427c588682fd86b1586f8ca83e4873370a0bd Mon Sep 17 00:00:00 2001
From: Robert David 
Date: Wed, 27 Nov 2019 16:00:40 -0800
Subject: [PATCH 076/279] Inline and optimize ApplyActivationToVector function.

Also implement missing activations.

PiperOrigin-RevId: 282848506
Change-Id: Iddba8efd6d28b2777347ebccdf8e410f6f65adb2
---
 tensorflow/lite/kernels/BUILD                 | 11 ---
 tensorflow/lite/kernels/activation_functor.h  | 58 --------------
 tensorflow/lite/kernels/internal/BUILD        |  1 -
 .../internal/optimized/neon_tensor_utils.h    |  9 ---
 .../internal/optimized/sse_tensor_utils.h     |  9 ---
 .../reference/portable_tensor_utils.cc        | 18 -----
 .../reference/portable_tensor_utils.h         |  9 ---
 .../reference/portable_tensor_utils_impl.h    |  9 ---
 .../lite/kernels/internal/tensor_utils.h      | 75 +++++++++++++++++--
 9 files changed, 70 insertions(+), 129 deletions(-)
 delete mode 100644 tensorflow/lite/kernels/activation_functor.h

diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD
index 6bbc6561143..2c9e596f2f1 100644
--- a/tensorflow/lite/kernels/BUILD
+++ b/tensorflow/lite/kernels/BUILD
@@ -354,17 +354,6 @@ cc_test(
     ],
 )
 
-cc_library(
-    name = "activation_functor",
-    hdrs = [
-        "activation_functor.h",
-    ],
-    copts = tflite_copts(),
-    deps = [
-        "//tensorflow/lite/c:common",
-    ],
-)
-
 cc_library(
     name = "op_macros",
     hdrs = [
diff --git a/tensorflow/lite/kernels/activation_functor.h b/tensorflow/lite/kernels/activation_functor.h
deleted file mode 100644
index 60e93c185a9..00000000000
--- a/tensorflow/lite/kernels/activation_functor.h
+++ /dev/null
@@ -1,58 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_LITE_KERNELS_ACTIVATION_FUNCTOR_H_
-#define TENSORFLOW_LITE_KERNELS_ACTIVATION_FUNCTOR_H_
-
-#include 
-#include 
-#include 
-
-#include "tensorflow/lite/c/builtin_op_data.h"
-
-namespace tflite {
-
-// Dynamic (non-fused) activation functor. perhaps it is worth having
-// template instantiation?
-// TODO(aselle): Make this more efficient by pulling the switch to conv_eval
-// using template inlining.
-class ActivationFunctor {
- public:
-  explicit ActivationFunctor(TfLiteFusedActivation act) : act_(act) {}
-
-  float operator()(float a) const {
-    switch (act_) {
-      case kTfLiteActNone:
-        return a;
-      case kTfLiteActRelu:
-        return a < 0.f ? 0.f : a;
-      case kTfLiteActRelu6:
-        return std::max(0.f, std::min(a, 6.f));
-      case kTfLiteActTanh:
-        return std::tanh(a);
-      case kTfLiteActSigmoid:
-        return 1.0f / (1.0f + std::exp(-a));
-      default:
-        // TODO(aselle): More informative fatal error!
-        exit(1);
-    }
-  }
-
- private:
-  TfLiteFusedActivation act_;
-};
-
-}  // namespace tflite
-
-#endif  // TENSORFLOW_LITE_KERNELS_ACTIVATION_FUNCTOR_H_
diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD
index d8bb8b41fff..93beae158ba 100644
--- a/tensorflow/lite/kernels/internal/BUILD
+++ b/tensorflow/lite/kernels/internal/BUILD
@@ -577,7 +577,6 @@ cc_library(
         ":compatibility",
         ":round",
         "//tensorflow/lite/c:common",
-        "//tensorflow/lite/kernels:activation_functor",
         "//tensorflow/lite/kernels:cpu_backend_context",
         "@gemmlowp",
     ],
diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h
index c4d3d0e13be..626afbe5d8d 100644
--- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h
+++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h
@@ -196,15 +196,6 @@ void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
   PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector);
 }
 
-void ApplySigmoidToVector(const float* vector, int v_size, float* result) {
-  PortableApplySigmoidToVector(vector, v_size, result);
-}
-
-void ApplyActivationToVector(const float* vector, int v_size,
-                             TfLiteFusedActivation activation, float* result) {
-  PortableApplyActivationToVector(vector, v_size, activation, result);
-}
-
 void Sub1Vector(const float* vector, int v_size, float* result) {
   NEON_OR_PORTABLE(Sub1Vector, vector, v_size, result);
 }
diff --git a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h
index 3659b6f4e1a..37c1c5ce05a 100644
--- a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h
+++ b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h
@@ -206,15 +206,6 @@ void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
   PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector);
 }
 
-void ApplySigmoidToVector(const float* vector, int v_size, float* result) {
-  PortableApplySigmoidToVector(vector, v_size, result);
-}
-
-void ApplyActivationToVector(const float* vector, int v_size,
-                             TfLiteFusedActivation activation, float* result) {
-  PortableApplyActivationToVector(vector, v_size, activation, result);
-}
-
 void Sub1Vector(const float* vector, int v_size, float* result) {
   NEON_OR_PORTABLE(Sub1Vector, vector, v_size, result);
 }
diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc
index dba6079009a..1ba34d45987 100644
--- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc
+++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc
@@ -21,7 +21,6 @@ limitations under the License.
 
 #include "fixedpoint/fixedpoint.h"
 #include "tensorflow/lite/c/builtin_op_data.h"
-#include "tensorflow/lite/kernels/activation_functor.h"
 #include "tensorflow/lite/kernels/cpu_backend_context.h"
 #include "tensorflow/lite/kernels/internal/common.h"
 #include "tensorflow/lite/kernels/internal/compatibility.h"
@@ -591,23 +590,6 @@ void PortableVectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
   }
 }
 
-void PortableApplySigmoidToVector(const float* vector, int v_size,
-                                  float* result) {
-  auto sigmoid_func = ActivationFunctor(kTfLiteActSigmoid);
-  for (int v = 0; v < v_size; v++) {
-    *result++ = (sigmoid_func)(*vector++);
-  }
-}
-
-void PortableApplyActivationToVector(const float* vector, int v_size,
-                                     TfLiteFusedActivation activation,
-                                     float* result) {
-  auto activation_func = ActivationFunctor(activation);
-  for (int v = 0; v < v_size; v++) {
-    *result++ = (activation_func)(*vector++);
-  }
-}
-
 void PortableSub1Vector(const float* vector, int v_size, float* result) {
   for (int v = 0; v < v_size; v++) {
     *result++ = 1.0f - *vector++;
diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h
index 9d8cf4e2b9a..587501fe2cb 100644
--- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h
+++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h
@@ -229,15 +229,6 @@ void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
   PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector);
 }
 
-void ApplySigmoidToVector(const float* vector, int v_size, float* result) {
-  PortableApplySigmoidToVector(vector, v_size, result);
-}
-
-void ApplyActivationToVector(const float* vector, int v_size,
-                             TfLiteFusedActivation activation, float* result) {
-  PortableApplyActivationToVector(vector, v_size, activation, result);
-}
-
 void Sub1Vector(const float* vector, int v_size, float* result) {
   PortableSub1Vector(vector, v_size, result);
 }
diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h
index ddc400bb0c9..954ef6716b6 100644
--- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h
+++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h
@@ -171,15 +171,6 @@ void PortableVectorBatchVectorAssign(const float* vector, int v_size,
 void PortableVectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
                                   float* batch_vector);
 
-// Apply sigmoid to elements of a vector.
-void PortableApplySigmoidToVector(const float* vector, int v_size,
-                                  float* result);
-
-// Apply activation function to elements of a vector.
-void PortableApplyActivationToVector(const float* vector, int v_size,
-                                     TfLiteFusedActivation activation,
-                                     float* result);
-
 // Compute "1.0f - elements of vector" (used in CIFG).
 void PortableSub1Vector(const float* vector, int v_size, float* result);
 
diff --git a/tensorflow/lite/kernels/internal/tensor_utils.h b/tensorflow/lite/kernels/internal/tensor_utils.h
index 7121403532a..60b6c95f76c 100644
--- a/tensorflow/lite/kernels/internal/tensor_utils.h
+++ b/tensorflow/lite/kernels/internal/tensor_utils.h
@@ -16,6 +16,7 @@ limitations under the License.
 #define TENSORFLOW_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_
 
 #include 
+#include 
 
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/kernels/cpu_backend_context.h"
@@ -401,12 +402,76 @@ void VectorBatchVectorAssign(const T* vector, int v_size, int n_batch,
   }
 }
 
-// Apply sigmoid to elements of a vector.
-void ApplySigmoidToVector(const float* vector, int v_size, float* result);
+// Apply Rectified Linear to elements of a vector.
+inline void ApplyReluToVector(const float* __restrict__ vector, int v_size,
+                              float* __restrict__ result) {
+  for (int v = 0; v < v_size; v++) {
+    result[v] = std::max(0.0f, vector[v]);
+  }
+}
 
-// Apply activation function to elements of a vector.
-void ApplyActivationToVector(const float* vector, int v_size,
-                             TfLiteFusedActivation activation, float* result);
+// Apply Rectified Linear 1 (cap to [-1;1]) to elements of a vector
+inline void ApplyRelu1ToVector(const float* __restrict__ vector, int v_size,
+                               float* __restrict__ result) {
+  for (int v = 0; v < v_size; v++) {
+    result[v] = std::max(-1.0f, std::min(vector[v], 1.0f));
+  }
+}
+
+// Apply Rectified Linear 6 (cap to [0;6]) to elements of a vector
+inline void ApplyRelu6ToVector(const float* __restrict__ vector, int v_size,
+                               float* __restrict__ result) {
+  for (int v = 0; v < v_size; v++) {
+    result[v] = std::max(0.0f, std::min(vector[v], 6.0f));
+  }
+}
+
+// Apply tanh to elements of a vector
+inline void ApplyTanhToVector(const float* __restrict__ vector, int v_size,
+                              float* __restrict__ result) {
+  for (int v = 0; v < v_size; v++) {
+    result[v] = std::tanh(vector[v]);
+  }
+}
+
+// Apply signbit to elements of a vector
+inline void ApplySignbitToVector(const float* __restrict__ vector, int v_size,
+                                 float* __restrict__ result) {
+  for (int v = 0; v < v_size; v++) {
+    result[v] = std::signbit(vector[v]);
+  }
+}
+
+// Apply sigmoid to elements of a vector.
+inline void ApplySigmoidToVector(const float* __restrict__ vector, int v_size,
+                                 float* __restrict__ result) {
+  for (int v = 0; v < v_size; v++) {
+    result[v] = 1.0f / (1.0f + std::exp(-vector[v]));
+  }
+}
+
+// Apply appropriate activation function to elements of a vector.
+inline void ApplyActivationToVector(const float* __restrict__ vector,
+                                    int v_size,
+                                    TfLiteFusedActivation activation,
+                                    float* __restrict__ result) {
+  switch (activation) {
+    case kTfLiteActNone:
+      return;
+    case kTfLiteActRelu:
+      return ApplyReluToVector(vector, v_size, result);
+    case kTfLiteActRelu1:
+      return ApplyRelu1ToVector(vector, v_size, result);
+    case kTfLiteActRelu6:
+      return ApplyRelu6ToVector(vector, v_size, result);
+    case kTfLiteActTanh:
+      return ApplyTanhToVector(vector, v_size, result);
+    case kTfLiteActSignBit:
+      return ApplySignbitToVector(vector, v_size, result);
+    case kTfLiteActSigmoid:
+      return ApplySigmoidToVector(vector, v_size, result);
+  }
+}
 
 // Compute "1.0f - elements of vector" (used in CIFG).
 void Sub1Vector(const float* vector, int v_size, float* result);

From 0d730f15370ca90036c02fff5ebf406f4772d4ef Mon Sep 17 00:00:00 2001
From: Eugene Zhulenev 
Date: Wed, 27 Nov 2019 16:46:24 -0800
Subject: [PATCH 077/279] Add benchmarks for Conv2D filter gradient with
 strides not equal to one

PiperOrigin-RevId: 282854489
Change-Id: Ib80cd9bc2e7458b6972f2e1225156f1c0bb123d8
---
 .../conv_grad_filter_ops_benchmark_test.cc    | 63 +++++++++++--------
 1 file changed, 37 insertions(+), 26 deletions(-)

diff --git a/tensorflow/core/kernels/conv_grad_filter_ops_benchmark_test.cc b/tensorflow/core/kernels/conv_grad_filter_ops_benchmark_test.cc
index 9b168045047..97148945331 100644
--- a/tensorflow/core/kernels/conv_grad_filter_ops_benchmark_test.cc
+++ b/tensorflow/core/kernels/conv_grad_filter_ops_benchmark_test.cc
@@ -40,7 +40,7 @@ template 
 static Graph* Conv2DBackpropFilter(int batch, int height, int width,
                                    int in_depth, int filter_h, int filter_w,
                                    int out_depth, int stride_h, int stride_w,
-                                   TensorFormat data_format) {
+                                   Padding padding, TensorFormat data_format) {
   auto* graph = new Graph(OpRegistry::Global());
 
   Tensor input_t = data_format == FORMAT_NHWC
@@ -53,7 +53,7 @@ static Graph* Conv2DBackpropFilter(int batch, int height, int width,
   Conv2DParameters params;
   params.dilations = {1, 1, 1, 1};
   params.strides = {1, stride_h, stride_w, 1};
-  params.padding = Padding::SAME;
+  params.padding = padding;
   params.data_format = data_format;
 
   Conv2DDimensions conv2d_dims;
@@ -85,7 +85,9 @@ static Graph* Conv2DBackpropFilter(int batch, int height, int width,
           .Input(backprop)
           .Attr("T", DataTypeToEnum::value)
           .Attr("strides", {1, stride_h, stride_w, 1})
-          .Attr("padding", "SAME")
+          .Attr("padding", padding == Padding::SAME
+                               ? "SAME"
+                               : padding == Padding::VALID ? "VALID" : "N/A")
           .Attr("data_format", ToString(data_format))
           .Finalize(graph, &conv2d));
 
@@ -94,7 +96,7 @@ static Graph* Conv2DBackpropFilter(int batch, int height, int width,
 
 // Macro arguments names: --------------------------------------------------- //
 //      T: data type
-// FORMAT: data format (NHWC or NCHW)
+// FMT: data format (NHWC or NCHW)
 //      N: batch size
 //      H: height
 //      W: width
@@ -107,38 +109,47 @@ static Graph* Conv2DBackpropFilter(int batch, int height, int width,
 
 #define BM_CONCAT(a, b) a##_##b
 
-#define BM_NAME(name, type, T, FORMAT, N, H, W, C, FH, FW, FC, SH, SW) \
-  BM_CONCAT(name##_##T##_##FORMAT##_##type##_in##N##x##H##x##W##x##C,  \
-            f##FH##x##FW##x##FC##_##s##SH##x##SW)
+#define BM_NAME(name, type, T, FMT, N, H, W, C, FH, FW, FC, SH, SW, PADDING) \
+  BM_CONCAT(name##_##T##_##FMT##_##type##_in##N##x##H##x##W##x##C,           \
+            f##FH##x##FW##x##FC##_##s##SH##x##SW##_##PADDING)
 
-#define BM_Conv2DBwdFilterFmt(T, FORMAT, N, H, W, C, FH, FW, FC, SH, SW, type) \
-  static void BM_NAME(BM_Conv2DBackpropFilter, type, T, FORMAT, N, H, W, C,    \
-                      FH, FW, FC, SH, SW)(int iters) {                         \
+#define BM_Conv2DBwdFilter(T, FMT, N, H, W, C, FH, FW, FC, SH, SW, PADDING,    \
+                           type)                                               \
+  static void BM_NAME(BM_Conv2DBackpropFilter, type, T, FMT, N, H, W, C, FH,   \
+                      FW, FC, SH, SW, PADDING)(int iters) {                    \
     testing::ItemsProcessed(static_cast(iters) * (N) * (H) * (W) *      \
                             (C));                                              \
     test::Benchmark(#type, Conv2DBackpropFilter(N, H, W, C, FH, FW, FC, SH, \
-                                                   SW, FORMAT_##FORMAT))       \
+                                                   SW, PADDING, FORMAT_##FMT)) \
         .Run(iters);                                                           \
   }                                                                            \
-  BENCHMARK(BM_NAME(BM_Conv2DBackpropFilter, type, T, FORMAT, N, H, W, C, FH,  \
-                    FW, FC, SH, SW));
+  BENCHMARK(BM_NAME(BM_Conv2DBackpropFilter, type, T, FMT, N, H, W, C, FH, FW, \
+                    FC, SH, SW, PADDING));
 
 // ResNet50-ish convolutions.
-#define BENCHMARK_DTYPE(FORMAT, BATCH, T, D)                                 \
-  BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 56, 56, 64, 1, 1, 64, 1, 1, D);    \
-  BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 56, 56, 64, 1, 1, 256, 1, 1, D);   \
-  BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 56, 56, 256, 1, 1, 64, 1, 1, D);   \
-  BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 56, 56, 64, 3, 3, 64, 1, 1, D);    \
+#define BENCHMARK_DTYPE(FMT, BATCH, T, D)                                    \
+  BM_Conv2DBwdFilter(T, FMT, BATCH, 112, 112, 64, 2, 2, 64, 2, 2, SAME, D);  \
+  BM_Conv2DBwdFilter(T, FMT, BATCH, 56, 56, 128, 2, 2, 128, 2, 2, SAME, D);  \
+  BM_Conv2DBwdFilter(T, FMT, BATCH, 28, 28, 256, 2, 2, 256, 2, 2, SAME, D);  \
                                                                              \
-  BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 28, 28, 128, 1, 1, 128, 1, 1, D);  \
-  BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 28, 28, 128, 1, 1, 512, 1, 1, D);  \
-  BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 28, 28, 512, 1, 1, 128, 1, 1, D);  \
-  BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 28, 28, 512, 3, 3, 128, 1, 1, D);  \
+  BM_Conv2DBwdFilter(T, FMT, BATCH, 112, 112, 64, 2, 2, 64, 2, 2, VALID, D); \
+  BM_Conv2DBwdFilter(T, FMT, BATCH, 56, 56, 128, 2, 2, 128, 2, 2, VALID, D); \
+  BM_Conv2DBwdFilter(T, FMT, BATCH, 28, 28, 256, 2, 2, 256, 2, 2, VALID, D); \
                                                                              \
-  BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 14, 14, 256, 1, 1, 256, 1, 1, D);  \
-  BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 14, 14, 256, 1, 1, 1024, 1, 1, D); \
-  BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 14, 14, 1024, 1, 1, 256, 1, 1, D); \
-  BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 14, 14, 256, 3, 3, 256, 1, 1, D);
+  BM_Conv2DBwdFilter(T, FMT, BATCH, 56, 56, 64, 1, 1, 64, 1, 1, SAME, D);    \
+  BM_Conv2DBwdFilter(T, FMT, BATCH, 56, 56, 64, 1, 1, 256, 1, 1, SAME, D);   \
+  BM_Conv2DBwdFilter(T, FMT, BATCH, 56, 56, 256, 1, 1, 64, 1, 1, SAME, D);   \
+  BM_Conv2DBwdFilter(T, FMT, BATCH, 56, 56, 64, 3, 3, 64, 1, 1, SAME, D);    \
+                                                                             \
+  BM_Conv2DBwdFilter(T, FMT, BATCH, 28, 28, 128, 1, 1, 128, 1, 1, SAME, D);  \
+  BM_Conv2DBwdFilter(T, FMT, BATCH, 28, 28, 128, 1, 1, 512, 1, 1, SAME, D);  \
+  BM_Conv2DBwdFilter(T, FMT, BATCH, 28, 28, 512, 1, 1, 128, 1, 1, SAME, D);  \
+  BM_Conv2DBwdFilter(T, FMT, BATCH, 28, 28, 512, 3, 3, 128, 1, 1, SAME, D);  \
+                                                                             \
+  BM_Conv2DBwdFilter(T, FMT, BATCH, 14, 14, 256, 1, 1, 256, 1, 1, SAME, D);  \
+  BM_Conv2DBwdFilter(T, FMT, BATCH, 14, 14, 256, 1, 1, 1024, 1, 1, SAME, D); \
+  BM_Conv2DBwdFilter(T, FMT, BATCH, 14, 14, 1024, 1, 1, 256, 1, 1, SAME, D); \
+  BM_Conv2DBwdFilter(T, FMT, BATCH, 14, 14, 256, 3, 3, 256, 1, 1, SAME, D);
 
 using fp32 = float;
 using fp16 = Eigen::half;

From f050412ecddfb771008165989946dcea3b9b60f8 Mon Sep 17 00:00:00 2001
From: Derek Murray 
Date: Wed, 27 Nov 2019 16:50:21 -0800
Subject: [PATCH 078/279] Lazily construct no-op
 OpKernelContext::Params::{inc,dec}_num_deferred_ops_function.

Each time we create an OpKernelContext::Params, we default-create no-op functions for these members. Since these functions are rarely used, this change defers their creation until the point of use.

PiperOrigin-RevId: 282854876
Change-Id: Ibdf5c034cffb001d2055413b29c328386b011693
---
 tensorflow/core/framework/op_kernel.h | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index 7275fb0484d..7f9895f7771 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -726,8 +726,8 @@ class OpKernelContext {
     const int* forward_from_array = nullptr;
 
     // For tracking actively running deferred ops.
-    std::function inc_num_deferred_ops_function = []() {};
-    std::function dec_num_deferred_ops_function = []() {};
+    std::function inc_num_deferred_ops_function;
+    std::function dec_num_deferred_ops_function;
   };
 
   // params must outlive the OpKernelContext.
@@ -1271,10 +1271,14 @@ class OpKernelContext {
   // functions. It then must call these two functions in pairs, before and after
   // device execution, respectively.
   TF_MUST_USE_RESULT std::function inc_num_deferred_ops_function() {
-    return params_->inc_num_deferred_ops_function;
+    return params_->inc_num_deferred_ops_function
+               ? params_->inc_num_deferred_ops_function
+               : []() {};
   }
   TF_MUST_USE_RESULT std::function dec_num_deferred_ops_function() {
-    return params_->dec_num_deferred_ops_function;
+    return params_->dec_num_deferred_ops_function
+               ? params_->dec_num_deferred_ops_function
+               : []() {};
   }
 
   Allocator* get_allocator(AllocatorAttributes attr);

From 20b978a0144b7a7d9fc103e49830523ac117a1fe Mon Sep 17 00:00:00 2001
From: Derek Murray 
Date: Wed, 27 Nov 2019 16:58:34 -0800
Subject: [PATCH 079/279] Rolling back "Add PrivateIntraProcessRendezvous." due
 to a data race in ExecutorBarrier.

PiperOrigin-RevId: 282855852
Change-Id: I7507b5f40cf71ab16ef338a4ddad93fd7588577a
---
 .../core/common_runtime/direct_session.cc     |  9 +-
 .../core/common_runtime/direct_session.h      |  4 +-
 tensorflow/core/common_runtime/executor.h     | 10 +-
 tensorflow/core/common_runtime/function.cc    | 12 +--
 .../core/common_runtime/function_test.cc      |  5 +-
 .../process_function_library_runtime_test.cc  | 39 ++++----
 .../core/common_runtime/rendezvous_mgr.cc     | 97 ++++++-------------
 .../core/common_runtime/rendezvous_mgr.h      | 65 +++++--------
 8 files changed, 98 insertions(+), 143 deletions(-)

diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 5d42ec208f7..133a6c31a93 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -521,8 +521,7 @@ Status DirectSession::RunInternal(
                             executor_step_count, &debugger_state));
   }
 
-  PrivateIntraProcessRendezvous rendezvous(device_mgr_.get());
-
+  run_state.rendez.reset(new IntraProcessRendezvous(device_mgr_.get()));
 #ifndef __ANDROID__
   // Set up for collectives if ExecutorsAndKeys declares a key.
   if (executors_and_keys->collective_graph_key !=
@@ -617,7 +616,7 @@ Status DirectSession::RunInternal(
   Executor::Args args;
   args.step_id = step_id;
   args.call_frame = call_frame;
-  args.rendezvous = &rendezvous;
+  args.rendezvous = run_state.rendez.get();
   args.collective_executor =
       (run_state.collective_executor ? run_state.collective_executor->get()
                                      : nullptr);
@@ -696,7 +695,7 @@ Status DirectSession::RunInternal(
     // `barrier` will delete itself after the final executor finishes.
     Notification executors_done;
     ExecutorBarrier* barrier =
-        new ExecutorBarrier(num_executors, &rendezvous,
+        new ExecutorBarrier(num_executors, run_state.rendez.get(),
                             [&run_state, &executors_done](const Status& ret) {
                               {
                                 mutex_lock l(run_state.mu);
@@ -1140,7 +1139,7 @@ Status DirectSession::SendPRunInputs(const NamedTensorList& inputs,
 
 Status DirectSession::RecvPRunOutputs(
     const std::vector& output_names,
-    const ExecutorsAndKeys* executors_and_keys, PartialRunState* run_state,
+    const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
     std::vector* outputs) {
   Status s;
   if (!output_names.empty()) {
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index 7bbb198ef44..a272633b4e2 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -191,6 +191,7 @@ class DirectSession : public Session {
   struct RunState {
     mutex mu;
     Status status GUARDED_BY(mu);
+    core::RefCountPtr rendez = nullptr;
     std::unique_ptr collective_executor;
     std::unique_ptr collector;
     TensorStore tensor_store;
@@ -207,7 +208,6 @@ class DirectSession : public Session {
     Notification executors_done;
     std::unordered_map pending_inputs;   // true if fed
     std::unordered_map pending_outputs;  // true if fetched
-    core::RefCountPtr rendez = nullptr;
 
     PartialRunState(const std::vector& pending_input_names,
                     const std::vector& pending_output_names,
@@ -282,7 +282,7 @@ class DirectSession : public Session {
   // tensors are computed.
   ::tensorflow::Status RecvPRunOutputs(
       const std::vector& output_names,
-      const ExecutorsAndKeys* executors_and_keys, PartialRunState* run_state,
+      const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
       std::vector* outputs);
 
   // Check if the specified fetches can be computed from the feeds
diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h
index 8e6e9bd9336..c147deee694 100644
--- a/tensorflow/core/common_runtime/executor.h
+++ b/tensorflow/core/common_runtime/executor.h
@@ -166,8 +166,8 @@ class ExecutorBarrier {
   //
   // 'done' is called after the last executor completes, and
   // ExecutorBarrier is deleted.
-  ExecutorBarrier(size_t num, RendezvousInterface* r, StatusCallback done)
-      : rendez_(r), done_cb_(std::move(done)), pending_(num) {}
+  ExecutorBarrier(size_t num, Rendezvous* r, StatusCallback done)
+      : rendez_(r), done_cb_(done), pending_(num) {}
 
   ~ExecutorBarrier() {}
 
@@ -178,7 +178,7 @@ class ExecutorBarrier {
   }
 
  private:
-  RendezvousInterface* rendez_ = nullptr;  // Not owned.
+  Rendezvous* rendez_ = nullptr;
   StatusCallback done_cb_ = nullptr;
 
   mutable mutex mu_;
@@ -186,7 +186,7 @@ class ExecutorBarrier {
   StatusGroup status_group_ GUARDED_BY(mu_);
 
   void WhenDone(const Status& s) {
-    RendezvousInterface* error_rendez = nullptr;
+    Rendezvous* error_rendez = nullptr;
     StatusCallback done = nullptr;
     Status status;
 
@@ -197,6 +197,7 @@ class ExecutorBarrier {
       // Rendezvous object by this thread only.
       if (status_group_.ok() && !s.ok()) {
         error_rendez = rendez_;
+        error_rendez->Ref();
       }
 
       if (!s.ok() && !StatusGroup::IsDerived(s) &&
@@ -218,6 +219,7 @@ class ExecutorBarrier {
     if (error_rendez != nullptr) {
       error_rendez->StartAbort(
           errors::Aborted("Stopping remaining executors."));
+      error_rendez->Unref();
     }
 
     if (done != nullptr) {
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 501002e1f7f..aa3be38fd29 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -1116,11 +1116,11 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
   }
   Options run_opts = opts;
   if (opts.create_rendezvous) {
-    auto* rendezvous = new PrivateIntraProcessRendezvous(device_mgr_);
+    Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_);
     run_opts.rendezvous = rendezvous;
     run_opts.create_rendezvous = false;
-    done = [done = std::move(done), rendezvous](const Status& status) mutable {
-      delete rendezvous;
+    done = [done = std::move(done), rendezvous](const Status& status) {
+      rendezvous->Unref();
       done(status);
     };
   }
@@ -1187,11 +1187,11 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
 
   Options run_opts = opts;
   if (opts.create_rendezvous) {
-    auto* rendezvous = new PrivateIntraProcessRendezvous(device_mgr_);
+    Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_);
     run_opts.rendezvous = rendezvous;
     run_opts.create_rendezvous = false;
-    done = [done = std::move(done), rendezvous](const Status& status) mutable {
-      delete rendezvous;
+    done = [done = std::move(done), rendezvous](const Status& status) {
+      rendezvous->Unref();
       done(status);
     };
   }
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 89e4daa50b3..7c76c469d1e 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -1854,8 +1854,8 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
 
   Tensor y;
   FunctionLibraryRuntime::Options opts;
-  PrivateIntraProcessRendezvous rendezvous(device_mgr_.get());
-  opts.rendezvous = &rendezvous;
+  Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_.get());
+  opts.rendezvous = rendezvous;
   opts.source_device = "/device:CPU:1";
   // Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1.
   TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y}, true));
@@ -1870,6 +1870,7 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
       y,
       test::AsTensor({"/job:localhost/replica:0/task:0/device:CPU:1"},
                               TensorShape({})));
+  rendezvous->Unref();
 }
 
 namespace {
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
index 55bc408f9c5..1a5ed3caa11 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
@@ -110,6 +110,12 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
     }
   }
 
+  ~ProcessFunctionLibraryRuntimeTest() override {
+    if (rendezvous_ != nullptr) {
+      rendezvous_->Unref();
+    }
+  }
+
   void Init(const std::vector& flib,
             const SessionMetadata* session_metadata = nullptr) {
     FunctionDefLibrary proto;
@@ -121,8 +127,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
         device_mgr_.get(), Env::Default(), /*config=*/nullptr,
         TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, nullptr, cluster_flr_.get(),
         nullptr, session_metadata));
-    rendezvous_ =
-        absl::make_unique(device_mgr_.get());
+    rendezvous_ = new IntraProcessRendezvous(device_mgr_.get());
   }
 
   Status Instantiate(
@@ -258,7 +263,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
           test::function::FunctionTestSchedClosure(fn);
         };
 
-    opts.rendezvous = rendezvous_.get();
+    opts.rendezvous = rendezvous_;
     opts.runner = &runner;
     Status status;
     Notification done;
@@ -287,7 +292,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
   std::unique_ptr lib_def_;
   std::unique_ptr cluster_flr_;
   std::unique_ptr proc_flr_;
-  std::unique_ptr rendezvous_ = nullptr;
+  IntraProcessRendezvous* rendezvous_ = nullptr;
 };
 
 TEST_F(ProcessFunctionLibraryRuntimeTest, GetFLRNull) {
@@ -339,7 +344,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCall) {
   Init({test::function::XTimesTwo()});
   FunctionLibraryRuntime::Options opts;
   opts.source_device = "/job:a/replica:0/task:0/cpu:0";
-  opts.rendezvous = rendezvous_.get();
+  opts.rendezvous = rendezvous_;
   opts.remote_execution = true;
   FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
   instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
@@ -354,7 +359,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCallFindDevice) {
   Init({test::function::FindDevice()});
   FunctionLibraryRuntime::Options opts;
   opts.source_device = "/job:a/replica:0/task:0/cpu:0";
-  opts.rendezvous = rendezvous_.get();
+  opts.rendezvous = rendezvous_;
   opts.remote_execution = true;
   FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
   instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
@@ -370,7 +375,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceXTimes) {
   auto x = test::AsTensor({1, 2, 3, 4});
   FunctionLibraryRuntime::Options opts;
   opts.source_device = "/job:a/replica:0/task:0/cpu:0";
-  opts.rendezvous = rendezvous_.get();
+  opts.rendezvous = rendezvous_;
   opts.remote_execution = true;
   FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
   instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
@@ -387,7 +392,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceFindDevice) {
   Init({test::function::FindDevice()});
   FunctionLibraryRuntime::Options opts;
   opts.source_device = "/job:a/replica:0/task:0/cpu:0";
-  opts.rendezvous = rendezvous_.get();
+  opts.rendezvous = rendezvous_;
   opts.remote_execution = true;
   FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
   instantiate_opts.target = "/job:a/replica:0/task:0/cpu:1";
@@ -406,7 +411,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) {
   Init({test::function::FindDevice()});
   FunctionLibraryRuntime::Options opts;
   opts.source_device = "/job:a/replica:0/task:0/cpu:0";
-  opts.rendezvous = rendezvous_.get();
+  opts.rendezvous = rendezvous_;
   opts.remote_execution = true;
   Tensor y;
   FunctionLibraryRuntime::InstantiateOptions instantiate_opts_0;
@@ -427,7 +432,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, ClusterFLRSerialTest) {
   Init({test::function::FindDevice()});
   FunctionLibraryRuntime::Options opts;
   opts.source_device = "/job:a/replica:0/task:0/cpu:0";
-  opts.rendezvous = rendezvous_.get();
+  opts.rendezvous = rendezvous_;
   opts.remote_execution = true;
   FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
   instantiate_opts.target = "/job:b/replica:0/task:0/device:CPU:0";
@@ -457,7 +462,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, ClusterFLRParallelTest) {
   Init({test::function::FindDevice()});
   FunctionLibraryRuntime::Options opts;
   opts.source_device = "/job:a/replica:0/task:0/cpu:0";
-  opts.rendezvous = rendezvous_.get();
+  opts.rendezvous = rendezvous_;
   opts.remote_execution = true;
   FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
   instantiate_opts.target = "/job:b/replica:0/task:0/device:CPU:0";
@@ -504,7 +509,7 @@ void TestTwoDeviceMult(
     const string& error = "") {
   fixture->Init({test::function::TwoDeviceMult()});
   FunctionLibraryRuntime::Options opts;
-  opts.rendezvous = fixture->rendezvous_.get();
+  opts.rendezvous = fixture->rendezvous_;
   auto x = test::AsTensor({1, 2, 3});
   Tensor y_cpu;
   Tensor y_gpu;
@@ -537,7 +542,7 @@ void TestTwoDeviceInputOutput(
   fixture->Init({test::function::TwoDeviceInputOutput()});
 
   FunctionLibraryRuntime::Options opts;
-  opts.rendezvous = fixture->rendezvous_.get();
+  opts.rendezvous = fixture->rendezvous_;
   Tensor x1 = test::AsTensor({1, 2});
   if (absl::StrContains(inst_opts.input_devices[0], "GPU")) {
     x1 = fixture->CPUToGPU(x1);
@@ -738,7 +743,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ResourceOutput_GPU) {
 
   // Run the function taking a resource and outputing it
   FunctionLibraryRuntime::Options opts;
-  opts.rendezvous = rendezvous_.get();
+  opts.rendezvous = rendezvous_;
   Tensor x1 = CPUToGPU(test::AsTensor({1, 2}));
   Tensor x2 = GetResourceHandle("my_gpu_var", mgr->default_container(),
                                 "/job:a/replica:0/task:0/device:GPU:0");
@@ -980,7 +985,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataAbsent) {
   Init({SessionMetadataReaderOpFn()}, /*session_metadata=*/nullptr);
   FunctionLibraryRuntime::Options opts;
   opts.source_device = "/job:a/replica:0/task:0/cpu:0";
-  opts.rendezvous = rendezvous_.get();
+  opts.rendezvous = rendezvous_;
   opts.remote_execution = true;
   FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
   instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
@@ -996,7 +1001,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresent) {
   Init({SessionMetadataReaderOpFn()}, &session_metadata);
   FunctionLibraryRuntime::Options opts;
   opts.source_device = "/job:a/replica:0/task:0/cpu:0";
-  opts.rendezvous = rendezvous_.get();
+  opts.rendezvous = rendezvous_;
   opts.remote_execution = true;
   FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
   instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
@@ -1022,7 +1027,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresentAfterCloning) {
   TF_ASSERT_OK(flr->Clone(&cloned_lib_def, &cloned_proc_flr, &cloned_flr));
   FunctionLibraryRuntime::Options opts;
   opts.source_device = "/job:a/replica:0/task:0/cpu:0";
-  opts.rendezvous = rendezvous_.get();
+  opts.rendezvous = rendezvous_;
   opts.remote_execution = true;
   FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
   instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.cc b/tensorflow/core/common_runtime/rendezvous_mgr.cc
index 6ed7df2cc1e..0d5e79667db 100644
--- a/tensorflow/core/common_runtime/rendezvous_mgr.cc
+++ b/tensorflow/core/common_runtime/rendezvous_mgr.cc
@@ -32,12 +32,23 @@ limitations under the License.
 
 namespace tensorflow {
 
-namespace {
-void SameWorkerRecvDone(const DeviceMgr* device_mgr,
-                        const Rendezvous::ParsedKey& parsed,
-                        const Rendezvous::Args& send_args,
-                        const Rendezvous::Args& recv_args, const Tensor& in,
-                        Tensor* out, StatusCallback done) {
+IntraProcessRendezvous::IntraProcessRendezvous(const DeviceMgr* device_mgr)
+    : device_mgr_(device_mgr) {}
+
+IntraProcessRendezvous::~IntraProcessRendezvous() {}
+
+Status IntraProcessRendezvous::Send(const ParsedKey& key,
+                                    const Rendezvous::Args& args,
+                                    const Tensor& val, const bool is_dead) {
+  VLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey();
+  // Buffers "val" and "device_context" in local_.
+  return local_.Send(key, args, val, is_dead);
+}
+
+void IntraProcessRendezvous::SameWorkerRecvDone(
+    const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& send_args,
+    const Rendezvous::Args& recv_args, const Tensor& in, Tensor* out,
+    StatusCallback done) {
   // Do a quick copy (sharing the underlying buffer) if both tensors
   // are on host memory.
   const bool src_host =
@@ -62,13 +73,13 @@ void SameWorkerRecvDone(const DeviceMgr* device_mgr,
   }
 
   Device* src_device;
-  Status s = device_mgr->LookupDevice(parsed.src_device, &src_device);
+  Status s = device_mgr_->LookupDevice(parsed.src_device, &src_device);
   if (!s.ok()) {
     done(s);
     return;
   }
   Device* dst_device;
-  s = device_mgr->LookupDevice(parsed.dst_device, &dst_device);
+  s = device_mgr_->LookupDevice(parsed.dst_device, &dst_device);
   if (!s.ok()) {
     done(s);
     return;
@@ -105,18 +116,16 @@ void SameWorkerRecvDone(const DeviceMgr* device_mgr,
       out, 0 /*dev_to_dev_stream_index*/, std::move(done), sync_dst_compute);
 }
 
-void IntraProcessRecvAsyncImpl(const DeviceMgr* device_mgr,
-                               LocalRendezvous* local,
-                               const RendezvousInterface::ParsedKey& parsed,
-                               const Rendezvous::Args& recv_args,
-                               RendezvousInterface::DoneCallback done) {
-  VLOG(1) << "IntraProcessRendezvous Recv " << local << " " << parsed.FullKey();
+void IntraProcessRendezvous::RecvAsync(const ParsedKey& key,
+                                       const Rendezvous::Args& args,
+                                       DoneCallback done) {
+  VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << key.FullKey();
 
   MEMDEBUG_CACHE_OP("RecvAsync");
   // Recv the tensor from local_.
-  local->RecvAsync(
-      parsed, recv_args,
-      [device_mgr, parsed, done = std::move(done)](
+  local_.RecvAsync(
+      key, args,
+      [this, key, done = std::move(done)](
           const Status& status, const Rendezvous::Args& send_args,
           const Rendezvous::Args& recv_args, const Tensor& in,
           bool is_dead) mutable {
@@ -132,7 +141,7 @@ void IntraProcessRecvAsyncImpl(const DeviceMgr* device_mgr,
         };
 
         if (status.ok() && in.IsInitialized()) {
-          SameWorkerRecvDone(device_mgr, parsed, send_args, recv_args, in, out,
+          SameWorkerRecvDone(key, send_args, recv_args, in, out,
                              std::move(final_callback));
         } else {
           final_callback(status);
@@ -140,56 +149,8 @@ void IntraProcessRecvAsyncImpl(const DeviceMgr* device_mgr,
       });
 }
 
-}  // namespace
-
-RefCountedIntraProcessRendezvous::RefCountedIntraProcessRendezvous(
-    const DeviceMgr* device_mgr)
-    : device_mgr_(device_mgr) {}
-
-RefCountedIntraProcessRendezvous::~RefCountedIntraProcessRendezvous() {}
-
-Status RefCountedIntraProcessRendezvous::Send(const ParsedKey& key,
-                                              const Rendezvous::Args& args,
-                                              const Tensor& val,
-                                              const bool is_dead) {
-  VLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey();
-  return local_.Send(key, args, val, is_dead);
-}
-
-void RefCountedIntraProcessRendezvous::RecvAsync(const ParsedKey& key,
-                                                 const Rendezvous::Args& args,
-                                                 DoneCallback done) {
-  VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << key.FullKey();
-  IntraProcessRecvAsyncImpl(device_mgr_, &local_, key, args, std::move(done));
-}
-
-void RefCountedIntraProcessRendezvous::StartAbort(const Status& s) {
-  local_.StartAbort(s);
-}
-
-PrivateIntraProcessRendezvous::PrivateIntraProcessRendezvous(
-    const DeviceMgr* device_mgr)
-    : device_mgr_(device_mgr) {}
-
-PrivateIntraProcessRendezvous::~PrivateIntraProcessRendezvous() {}
-
-Status PrivateIntraProcessRendezvous::Send(const ParsedKey& key,
-                                           const Rendezvous::Args& args,
-                                           const Tensor& val,
-                                           const bool is_dead) {
-  DVLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey();
-  return local_.Send(key, args, val, is_dead);
-}
-
-void PrivateIntraProcessRendezvous::RecvAsync(const ParsedKey& key,
-                                              const Rendezvous::Args& args,
-                                              DoneCallback done) {
-  DVLOG(1) << "StackAllocatedIntraProcessRendezvous Recv " << this << " "
-           << key.FullKey();
-  IntraProcessRecvAsyncImpl(device_mgr_, &local_, key, args, std::move(done));
-}
-
-void PrivateIntraProcessRendezvous::StartAbort(const Status& s) {
+void IntraProcessRendezvous::StartAbort(const Status& s) {
+  CHECK(!s.ok());
   local_.StartAbort(s);
 }
 
diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.h b/tensorflow/core/common_runtime/rendezvous_mgr.h
index eea5fbe388c..a9d3de122f0 100644
--- a/tensorflow/core/common_runtime/rendezvous_mgr.h
+++ b/tensorflow/core/common_runtime/rendezvous_mgr.h
@@ -30,61 +30,48 @@ limitations under the License.
 
 namespace tensorflow {
 
-// The IntraProcessRendezvous classes are implementations of a Rendezvous that
-// expects all producers and consumers to be devices immediately accessible
-// within the process. That is, it will never be necessary to perform an RPC to
+// IntraProcessRendezvous is a Rendezvous which expects all producers
+// and consumers to be devices immediately accessible within the
+// process. That is, it will never be necessary to perform an RPC to
 // communicate with either.
 //
-// Buffering of Tensor values is delegated to a `LocalRendezvous`. An
-// IntraProcessRendezvous. just adds functionality to coordinate multiple
-// process-local devices.
-
-// Reference-counted implementation that may be shared between multiple threads.
-class RefCountedIntraProcessRendezvous : public Rendezvous {
+// Buffering of Tensor values is delegated to a `LocalRendezvous`. This class
+// just adds functionality to coordinate multiple process-local devices.
+class IntraProcessRendezvous : public Rendezvous {
  public:
-  explicit RefCountedIntraProcessRendezvous(const DeviceMgr* device_mgr);
+  explicit IntraProcessRendezvous(const DeviceMgr* device_mgr);
 
-  // Implementation of RendezvousInterface methods.
+  // Forwards to local_, where the Tensor "val" will be buffered and
+  // any waiting callback stored.
   Status Send(const ParsedKey& key, const Rendezvous::Args& args,
               const Tensor& val, const bool is_dead) override;
+
+  // This method is called only by the RecvOp.  It tests to see
+  // whether the value will be produced by a local or remote device
+  // and handles accordingly.  In the local case it forwards to
+  // local_, in the remote case it initiates an RPC request.
   void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args,
                  DoneCallback done) override;
+
   void StartAbort(const Status& status) override;
 
  private:
   const DeviceMgr* device_mgr_;
   LocalRendezvous local_;
 
-  ~RefCountedIntraProcessRendezvous() override;
+  ~IntraProcessRendezvous() override;
 
-  TF_DISALLOW_COPY_AND_ASSIGN(RefCountedIntraProcessRendezvous);
-};
+  // Callback handling the case when a rendezvous has been
+  // accomplished in local_ and the consumer is local to this process.
+  // Tensor "in" will be copied into "out". The key "parsed" encodes
+  // the src and dst devices.
+  typedef std::function StatusCallback;
+  void SameWorkerRecvDone(const Rendezvous::ParsedKey& parsed,
+                          const Rendezvous::Args& send_args,
+                          const Rendezvous::Args& recv_args, const Tensor& in,
+                          Tensor* out, StatusCallback done);
 
-// RefCountedIntraProcessRendezvous is aliased to IntraProcessRendezvous for
-// backwards compatibility with existing users.
-using IntraProcessRendezvous = RefCountedIntraProcessRendezvous;
-
-// Non-reference-counted implementation that may be stack-allocated for
-// performance.
-//
-// Prefer to use PrivateIntraProcessRendezvous in new code.
-class PrivateIntraProcessRendezvous : public RendezvousInterface {
- public:
-  explicit PrivateIntraProcessRendezvous(const DeviceMgr* device_mgr);
-  ~PrivateIntraProcessRendezvous() override;
-
-  // Implementation of RendezvousInterface methods.
-  Status Send(const ParsedKey& key, const Rendezvous::Args& args,
-              const Tensor& val, const bool is_dead) override;
-  void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args,
-                 DoneCallback done) override;
-  void StartAbort(const Status& status) override;
-
- private:
-  const DeviceMgr* device_mgr_;
-  LocalRendezvous local_;
-
-  TF_DISALLOW_COPY_AND_ASSIGN(PrivateIntraProcessRendezvous);
+  TF_DISALLOW_COPY_AND_ASSIGN(IntraProcessRendezvous);
 };
 
 }  // end namespace tensorflow

From ff731ce2f326bc4335094af5e28f88ee2d74f54c Mon Sep 17 00:00:00 2001
From: Robert David 
Date: Wed, 27 Nov 2019 17:08:54 -0800
Subject: [PATCH 080/279] Use Eigen in ApplyTanhToVector and
 ApplySigmoidToVector.

Remove the "override" that did this already in hybrid LSTM.

The Eigen implementation has slightly worse accuracy, but is implemented using SIMD.

PiperOrigin-RevId: 282857162
Change-Id: Iddd383d163773e8ca72ee51e02d2cb7cb249a82f
---
 tensorflow/lite/kernels/BUILD                 |  1 -
 tensorflow/lite/kernels/internal/BUILD        |  6 +--
 .../lite/kernels/internal/tensor_utils.h      | 15 ++++---
 tensorflow/lite/kernels/lstm_eval.cc          | 43 +++++--------------
 4 files changed, 20 insertions(+), 45 deletions(-)

diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD
index 2c9e596f2f1..b3657228e63 100644
--- a/tensorflow/lite/kernels/BUILD
+++ b/tensorflow/lite/kernels/BUILD
@@ -603,7 +603,6 @@ cc_library(
         "//tensorflow/lite/kernels/internal:kernel_utils",
         "//tensorflow/lite/kernels/internal:tensor",
         "//tensorflow/lite/kernels/internal:tensor_utils",
-        "//third_party/eigen3",
         "@gemmlowp",
     ],
 )
diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD
index 93beae158ba..646f14680ac 100644
--- a/tensorflow/lite/kernels/internal/BUILD
+++ b/tensorflow/lite/kernels/internal/BUILD
@@ -669,14 +669,10 @@ cc_library(
     ],
     copts = tflite_copts() + NEON_FLAGS_IF_APPLICABLE,
     deps = [
-        ":common",
-        ":compatibility",
         ":cpu_check",
-        ":types",
+        "//third_party/eigen3",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/kernels:cpu_backend_context",
-        "//tensorflow/lite/kernels:op_macros",
-        "@gemmlowp//:fixedpoint",
     ] + select({
         ":aarch64": [
             ":neon_tensor_utils",
diff --git a/tensorflow/lite/kernels/internal/tensor_utils.h b/tensorflow/lite/kernels/internal/tensor_utils.h
index 60b6c95f76c..fccd058bea5 100644
--- a/tensorflow/lite/kernels/internal/tensor_utils.h
+++ b/tensorflow/lite/kernels/internal/tensor_utils.h
@@ -18,6 +18,7 @@ limitations under the License.
 #include 
 #include 
 
+#include "third_party/eigen3/Eigen/Core"
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/kernels/cpu_backend_context.h"
 
@@ -429,9 +430,10 @@ inline void ApplyRelu6ToVector(const float* __restrict__ vector, int v_size,
 // Apply tanh to elements of a vector
 inline void ApplyTanhToVector(const float* __restrict__ vector, int v_size,
                               float* __restrict__ result) {
-  for (int v = 0; v < v_size; v++) {
-    result[v] = std::tanh(vector[v]);
-  }
+  using VectorMap = Eigen::Map>;
+  VectorMap input_map(const_cast(vector), v_size);
+  VectorMap output_map(result, v_size);
+  output_map.array() = input_map.array().tanh();
 }
 
 // Apply signbit to elements of a vector
@@ -445,9 +447,10 @@ inline void ApplySignbitToVector(const float* __restrict__ vector, int v_size,
 // Apply sigmoid to elements of a vector.
 inline void ApplySigmoidToVector(const float* __restrict__ vector, int v_size,
                                  float* __restrict__ result) {
-  for (int v = 0; v < v_size; v++) {
-    result[v] = 1.0f / (1.0f + std::exp(-vector[v]));
-  }
+  using VectorMap = Eigen::Map>;
+  VectorMap input_map(const_cast(vector), v_size);
+  VectorMap output_map(result, v_size);
+  output_map.array() = input_map.array().logistic();
 }
 
 // Apply appropriate activation function to elements of a vector.
diff --git a/tensorflow/lite/kernels/lstm_eval.cc b/tensorflow/lite/kernels/lstm_eval.cc
index 6773b691cfd..ba631a6ee24 100644
--- a/tensorflow/lite/kernels/lstm_eval.cc
+++ b/tensorflow/lite/kernels/lstm_eval.cc
@@ -24,7 +24,6 @@ limitations under the License.
 #include "profiling/profiler.h"
 #endif
 
-#include "third_party/eigen3/Eigen/Core"
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
@@ -363,28 +362,6 @@ inline void LstmStepWithAuxInput(
   }
 }
 
-void ApplyActivationsToVector(float* input, int input_size,
-                              TfLiteFusedActivation activation_type,
-                              float* output) {
-  using VectorMap = Eigen::Map>;
-  VectorMap input_map(input, input_size, 1);
-  VectorMap output_map(output, input_size, 1);
-  switch (activation_type) {
-    case kTfLiteActSigmoid: {
-      output_map.array() = input_map.array().logistic();
-      break;
-    }
-    case kTfLiteActTanh: {
-      output_map.array() = input_map.array().tanh();
-      break;
-    }
-    default: {
-      tensor_utils::ApplyActivationToVector(input, input_size, activation_type,
-                                            output);
-    }
-  }
-}
-
 // Same as above but with quantized weight matrices. In detail:
 // Input of size 'n_batch * n_input':
 //   input_ptr_batch
@@ -699,8 +676,8 @@ inline void LstmStepWithAuxInput(
       tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
                                          input_gate_scratch);
     }
-    ApplyActivationsToVector(input_gate_scratch, n_cell * n_batch,
-                             kTfLiteActSigmoid, input_gate_scratch);
+    tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
+                                       input_gate_scratch);
   }
 
   // For each batch and cell: update forget gate.
@@ -721,8 +698,8 @@ inline void LstmStepWithAuxInput(
     tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
                                        forget_gate_scratch);
   }
-  ApplyActivationsToVector(forget_gate_scratch, n_cell * n_batch,
-                           kTfLiteActSigmoid, forget_gate_scratch);
+  tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
+                                     forget_gate_scratch);
 
   // For each batch and cell: update the cell.
   tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
@@ -736,8 +713,8 @@ inline void LstmStepWithAuxInput(
     tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
                                        cell_scratch);
   }
-  ApplyActivationsToVector(cell_scratch, n_batch * n_cell, params->activation,
-                           cell_scratch);
+  tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
+                                        params->activation, cell_scratch);
   if (use_cifg) {
     tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
                              forget_gate_scratch);
@@ -772,10 +749,10 @@ inline void LstmStepWithAuxInput(
     tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
                                        output_gate_scratch);
   }
-  ApplyActivationsToVector(output_gate_scratch, n_batch * n_cell,
-                           kTfLiteActSigmoid, output_gate_scratch);
-  ApplyActivationsToVector(cell_state_ptr, n_batch * n_cell, params->activation,
-                           cell_scratch);
+  tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
+                                     output_gate_scratch);
+  tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
+                                        params->activation, cell_scratch);
   tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
                                          n_batch * n_cell, output_gate_scratch);
 

From f15b0dbe034e6b2ac655844c4a502954ed741bd5 Mon Sep 17 00:00:00 2001
From: Gaurav Jain 
Date: Wed, 27 Nov 2019 17:19:09 -0800
Subject: [PATCH 081/279] Merge the TensorHandle mutexes into one

PiperOrigin-RevId: 282858248
Change-Id: I61a2115d45bf64b8e0fd45eb93099dddf42a4430
---
 .../core/common_runtime/eager/tensor_handle.cc     | 14 +++++++-------
 .../core/common_runtime/eager/tensor_handle.h      | 10 +++++-----
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc
index e4b297b646f..fe8337ce1fc 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc
@@ -401,7 +401,7 @@ Status TensorHandle::NumElements(int64* num_elements) {
 Status TensorHandle::RemoteAddress(Device* d, int64* op_id,
                                    int32* output_num) const {
   if (d != device_) {
-    tf_shared_lock l(remote_mirrors_mutex_);
+    tf_shared_lock l(mu_);
     auto mirror = remote_mirrors_.find(d);
     if (mirror != remote_mirrors_.end()) {
       *op_id = mirror->second->op_id();
@@ -439,7 +439,7 @@ void TensorHandle::SetRemoteOpIdAndOutputNumToLocalTensorHandle(
 }
 
 bool TensorHandle::HasRemoteMirror(Device* d) {
-  tf_shared_lock l(remote_mirrors_mutex_);
+  tf_shared_lock l(mu_);
   auto mirror = remote_mirrors_.find(d);
   if (mirror != remote_mirrors_.end()) {
     return true;
@@ -454,7 +454,7 @@ bool TensorHandle::HasRemoteMirror(Device* d) {
 }
 
 bool TensorHandle::HasResourceShapeMirror(Device* d) {
-  tf_shared_lock l(resource_shape_mirrors_mutex_);
+  tf_shared_lock l(mu_);
   auto mirror = resource_shape_mirrors_.find(d);
   if (mirror != resource_shape_mirrors_.end()) {
     return true;
@@ -464,7 +464,7 @@ bool TensorHandle::HasResourceShapeMirror(Device* d) {
 
 Status TensorHandle::AddUnshapedRemoteMirror(
     std::unique_ptr t, Device* d) {
-  mutex_lock l(remote_mirrors_mutex_);
+  mutex_lock l(mu_);
   if (remote_mirrors_.find(d) != remote_mirrors_.end()) {
     return errors::Internal("Attempted to duplicate a remote mirror.");
   }
@@ -480,7 +480,7 @@ Status TensorHandle::AddUnshapedRemoteMirror(
 
 Status TensorHandle::AddResourceShapeMirror(
     std::unique_ptr t, Device* d) {
-  mutex_lock l(resource_shape_mirrors_mutex_);
+  mutex_lock l(mu_);
   auto ret = resource_shape_mirrors_.insert(std::make_pair(d, std::move(t)));
   if (!ret.second) {
     return errors::Internal("Attempted to duplicate a resource shape mirror.");
@@ -491,7 +491,7 @@ Status TensorHandle::AddResourceShapeMirror(
 
 Status TensorHandle::AddRemoteMirror(std::unique_ptr t,
                                      Device* d) {
-  mutex_lock l(remote_mirrors_mutex_);
+  mutex_lock l(mu_);
   auto ret = remote_mirrors_.insert(std::make_pair(d, std::move(t)));
   if (!ret.second) {
     return errors::Internal("Attempted to duplicate a remote mirror.");
@@ -505,7 +505,7 @@ Status TensorHandle::SetRemoteShape(const TensorShape& shape,
   DVLOG(3) << "SetRemoteShape on TensorHandle: " << this << " device: " << d;
 
   if (d != device_) {
-    mutex_lock l(remote_mirrors_mutex_);
+    mutex_lock l(mu_);
     if (remote_mirrors_.find(d) != remote_mirrors_.end()) {
       return errors::Internal(
           "Attempted to set remote shape for existing mirror.");
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h
index a8b05e34b43..b0393addd4d 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.h
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.h
@@ -233,23 +233,23 @@ class TensorHandle : public core::RefCounted {
   tensorflow::Device* const resource_device_;
 
 #if !defined(IS_MOBILE_PLATFORM)
+  mutable mutex mu_;
+
   // TODO(yujingzhang): Remove resource_shape_mirrors_ once scalable per-replica
   // variable is ready, since we could get the shape locally without remote copy
   // then.
-  mutable mutex resource_shape_mirrors_mutex_;
   std::map>
-      resource_shape_mirrors_ GUARDED_BY(resource_shape_mirrors_mutex_);
+      resource_shape_mirrors_ GUARDED_BY(mu_);
 
-  mutable mutex remote_mirrors_mutex_;
   // TODO(gjn): Unshaped remote mirrors are long expected to be long-lived.
   // Consider replacing the unshaped_remote_mirrors_ map with something more
   // efficient.
   std::map>
-      unshaped_remote_mirrors_ GUARDED_BY(remote_mirrors_mutex_);
+      unshaped_remote_mirrors_ GUARDED_BY(mu_);
   // TODO(gjn): Is std::map the most optimal choice here? Perhaps this should be
   // a fixed size map.
   std::map>
-      remote_mirrors_ GUARDED_BY(remote_mirrors_mutex_);
+      remote_mirrors_ GUARDED_BY(mu_);
 
   // IDs required when this class is representing a remote tensor handle.
   int64 remote_op_id_;

From 9c2db02c90897354a23adad4a824ae520c511577 Mon Sep 17 00:00:00 2001
From: Gaurav Jain 
Date: Wed, 27 Nov 2019 18:12:02 -0800
Subject: [PATCH 082/279] Optimized SetTensor to use std::move

PiperOrigin-RevId: 282863249
Change-Id: I9013084c63b27a210aec1b1d0a10c5da6e6d8e73
---
 tensorflow/core/common_runtime/eager/copy_to_device_node.h    | 2 +-
 tensorflow/core/common_runtime/eager/execute.cc               | 2 +-
 tensorflow/core/common_runtime/eager/tensor_handle.cc         | 2 +-
 tensorflow/core/common_runtime/eager/tensor_handle.h          | 2 +-
 tensorflow/core/distributed_runtime/eager/remote_copy_node.cc | 2 +-
 5 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/tensorflow/core/common_runtime/eager/copy_to_device_node.h b/tensorflow/core/common_runtime/eager/copy_to_device_node.h
index 144184fac9a..53f3ff94d78 100644
--- a/tensorflow/core/common_runtime/eager/copy_to_device_node.h
+++ b/tensorflow/core/common_runtime/eager/copy_to_device_node.h
@@ -43,7 +43,7 @@ class CopyToDeviceNode : public EagerNode {
     MEMDEBUG_CACHE_OP(MEMDEBUG_CACHE_VAL ? MEMDEBUG_CACHE_VAL
                                          : "eager::CopyToDeviceNode");
     TF_RETURN_IF_ERROR(src_->CopyToDevice(ctx_, dstd_, &tensor));
-    return dst_->SetTensor(tensor);
+    return dst_->SetTensor(std::move(tensor));
   }
 
   void Abort(Status status) override { dst_->Poison(status); }
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index e2c424e8ed6..6ab90a0b940 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -1089,7 +1089,7 @@ Status EagerKernelExecute(
     DCHECK_EQ(ctx->CanonicalDevice(kernel->OutputDevice(i)),
               retvals[i]->device());
 
-    TF_RETURN_IF_ERROR(retvals[i]->SetTensor(outputs[i]));
+    TF_RETURN_IF_ERROR(retvals[i]->SetTensor(std::move(outputs[i])));
   }
   return Status::OK();
 }
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc
index fe8337ce1fc..a40686c457f 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc
@@ -545,7 +545,7 @@ Status TensorHandle::SetRemoteShape(const TensorShape& shape,
 }
 #endif
 
-Status TensorHandle::SetTensor(const tensorflow::Tensor& tensor) {
+Status TensorHandle::SetTensor(tensorflow::Tensor&& tensor) {
   DCHECK(!is_remote_) << "SetTensor is not called on remote handles.";
   DCHECK(!is_ready_notification_.HasBeenNotified())
       << "SetTensor is only called on non-ready handles.";
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h
index b0393addd4d..7372885ed74 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.h
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.h
@@ -158,7 +158,7 @@ class TensorHandle : public core::RefCounted {
   // Sets the `tensor` for this async non-ready handle making it ready.
   // This method or Poison must be called exactly once for non-ready async
   // handles to make them ready.
-  Status SetTensor(const tensorflow::Tensor& tensor);
+  Status SetTensor(tensorflow::Tensor&& tensor);
 
   // Poisons this non-ready handle with an error `status`.
   // Poisoning means that the handle will become ready and methods trying
diff --git a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc
index 55b16b2587f..0dfcd82d737 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc
+++ b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc
@@ -277,7 +277,7 @@ void RemoteCopyNode::StartRecv(StatusCallback done) {
       done(status);
       return;
     }
-    status = captured_state_->dst()->SetTensor(outputs[0]);
+    status = captured_state_->dst()->SetTensor(std::move(outputs[0]));
     done(status);
   } else {
     // Handles captured_state_->dst_ internally.

From c4765b1d2771300320844673c37764c31b874680 Mon Sep 17 00:00:00 2001
From: Shanqing Cai 
Date: Wed, 27 Nov 2019 19:55:10 -0800
Subject: [PATCH 083/279] [tfdbg] Record graph context hierarchy in DebugEvent
 file.

- Write `DebuggedGraph` proto to the DebugEvent file with suffix
  ".graphs" during graph creation. This keeps track of the hierarchy
  of contexts ("graphlets" such as cond and body of TF while loops)

Also in this CL:
- Remove the deprecated "U" mode for file reading in source_utils.py
  in order to remove the warning.

PiperOrigin-RevId: 282871417
Change-Id: I2d588d711f77159573288f8f9a2fa40d4a4911c5
---
 tensorflow/core/protobuf/debug_event.proto    |  6 +-
 .../debug/lib/distributed_callbacks_test.py   |  4 +-
 .../python/debug/lib/dumping_callback.py      | 32 +++++++-
 .../python/debug/lib/dumping_callback_test.py | 69 +++++++++++++++---
 .../debug/lib/dumping_callback_test_lib.py    | 73 ++++++++++++++-----
 tensorflow/python/debug/lib/source_utils.py   |  2 +-
 6 files changed, 151 insertions(+), 35 deletions(-)

diff --git a/tensorflow/core/protobuf/debug_event.proto b/tensorflow/core/protobuf/debug_event.proto
index 06499c2406c..8f9680f38d9 100644
--- a/tensorflow/core/protobuf/debug_event.proto
+++ b/tensorflow/core/protobuf/debug_event.proto
@@ -87,8 +87,7 @@ message DebugEvent {
     // a Python function).
     GraphOpCreation graph_op_creation = 7;
 
-    // Information about a debugged graph, including its graph def and
-    // list of the graph's ops that are instrumented.
+    // Information about a debugged graph.
     DebuggedGraph debugged_graph = 8;
 
     // Execution of an op or a Graph (e.g., a tf.function).
@@ -200,6 +199,9 @@ message DebuggedGraph {
   // An encoded version of a GraphDef.
   // This graph may include the debugger-inserted ops.
   bytes instrumented_graph_def = 5;
+
+  // IDs of the immediate enclosing context (graph), if any.
+  string outer_context_id = 6;
 }
 
 // Data relating to the eager execution of an op or a Graph.
diff --git a/tensorflow/python/debug/lib/distributed_callbacks_test.py b/tensorflow/python/debug/lib/distributed_callbacks_test.py
index bd9d908fd36..e1ff0f823c3 100644
--- a/tensorflow/python/debug/lib/distributed_callbacks_test.py
+++ b/tensorflow/python/debug/lib/distributed_callbacks_test.py
@@ -178,7 +178,7 @@ class DistributedDumpingCallbackTest(
 
     stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
     (context_ids, _,
-     op_name_to_op_type) = self._readAndCheckGraphsFile(stack_frame_by_id)
+     op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
     (op_names, device_names, _,
      tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
     executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
@@ -261,7 +261,7 @@ class DistributedDumpingCallbackTest(
 
     stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
     (context_ids, _,
-     op_name_to_op_type) = self._readAndCheckGraphsFile(stack_frame_by_id)
+     op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
     (op_names, device_names, _,
      tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
 
diff --git a/tensorflow/python/debug/lib/dumping_callback.py b/tensorflow/python/debug/lib/dumping_callback.py
index 96427536e1d..adb924aefaa 100644
--- a/tensorflow/python/debug/lib/dumping_callback.py
+++ b/tensorflow/python/debug/lib/dumping_callback.py
@@ -119,6 +119,8 @@ class _DumpingCallback(object):
     """Get a unique ID for an op-construction context (e.g., a graph).
 
     If the graph has been encountered before, reuse the same unique ID.
+    When encountering a new context (graph), this methods writes a DebugEvent
+    proto with the debugged_graph field to the proper DebugEvent file.
 
     Args:
       context: A context to get the unique ID for. Must be hashable. E.g., a
@@ -130,10 +132,34 @@ class _DumpingCallback(object):
     # Use the double-checked lock pattern to optimize the common case.
     if context in self._context_to_id:  # 1st check, without lock.
       return self._context_to_id[context]
+    graph_is_new = False
     with self._context_to_id_lock:
       if context not in self._context_to_id:  # 2nd check, with lock.
-        self._context_to_id[context] = _get_id()
-      return self._context_to_id[context]
+        graph_is_new = True
+        context_id = _get_id()
+        self._context_to_id[context] = context_id
+    if graph_is_new:
+      self.get_writer().WriteDebuggedGraph(debug_event_pb2.DebuggedGraph(
+          graph_id=context_id,
+          graph_name=getattr(context, "name", None),
+          outer_context_id=self._get_outer_context_id(context)))
+    return self._context_to_id[context]
+
+  def _get_outer_context_id(self, graph):
+    """Get the ID of the immediate outer context of the input graph.
+
+    Args:
+      graph: The graph (context) in question.
+
+    Returns:
+      If an outer context exists, the immediate outer context name as a string.
+      If such as outer context does not exist (i.e., `graph` is itself
+      outermost), `None`.
+    """
+    if hasattr(graph, "outer_graph") and graph.outer_graph:
+      return self._get_context_id(graph.outer_graph)
+    else:
+      return None
 
   def _write_source_file_content(self, file_path):
     """Send the content of a source file via debug-events writer.
@@ -352,7 +378,7 @@ class _DumpingCallback(object):
 
     writer = self.get_writer()
     if graph:
-      context_id = self._get_context_id(graph)
+      context_id = self._get_context_id(graph)  # Innermost context ID.
       assert op_name is not None
       output_tensor_ids = self._get_symbolic_tensor_ids(len(outputs))
       graph_op_creation = debug_event_pb2.GraphOpCreation(
diff --git a/tensorflow/python/debug/lib/dumping_callback_test.py b/tensorflow/python/debug/lib/dumping_callback_test.py
index 8cc0242c062..ed222585454 100644
--- a/tensorflow/python/debug/lib/dumping_callback_test.py
+++ b/tensorflow/python/debug/lib/dumping_callback_test.py
@@ -225,7 +225,7 @@ class TracingCallbackTest(
 
     stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
     (context_ids, op_types,
-     op_name_to_op_type) = self._readAndCheckGraphsFile(stack_frame_by_id)
+     op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
     self.assertIn("AddV2", op_types)
     self.assertIn("Log", op_types)
     self.assertIn("Sin", op_types)
@@ -276,7 +276,7 @@ class TracingCallbackTest(
 
     stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
     (context_ids, op_types,
-     op_name_to_op_type) = self._readAndCheckGraphsFile(stack_frame_by_id)
+     op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
     self.assertIn("AddV2", op_types)
     self.assertIn("Log", op_types)
     self.assertIn("Sin", op_types)
@@ -354,7 +354,7 @@ class TracingCallbackTest(
     writer.FlushExecutionFiles()
     stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
     (context_ids, _,
-     op_name_to_op_type) = self._readAndCheckGraphsFile(stack_frame_by_id)
+     op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
     (op_names, _, _,
      tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
     executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
@@ -417,7 +417,7 @@ class TracingCallbackTest(
     stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
 
     # Verify the content of the .graphs file.
-    context_ids, op_types, op_name_to_op_type = (
+    context_ids, op_types, op_name_to_op_type, _ = (
         self._readAndCheckGraphsFile(stack_frame_by_id))
     self.assertIn("Less", op_types)
     self.assertIn("Mul", op_types)
@@ -555,7 +555,7 @@ class TracingCallbackTest(
     writer.FlushNonExecutionFiles()
     writer.FlushExecutionFiles()
     stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
-    context_ids, _, _ = self._readAndCheckGraphsFile(stack_frame_by_id)
+    context_ids, _, _, _ = self._readAndCheckGraphsFile(stack_frame_by_id)
     _, _, _, _, tensor_values = self._readAndCheckExecutionFile()
     self.assertEqual(tensor_values, [[]])
     (_, _, _,
@@ -638,7 +638,7 @@ class TracingCallbackTest(
       prev_wall_time = debug_event.wall_time
 
     (context_ids, _,
-     op_name_to_op_type) = self._readAndCheckGraphsFile(stack_frame_by_id)
+     op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
 
     (op_names, _, output_slots,
      tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
@@ -718,6 +718,57 @@ class TracingCallbackTest(
     v2_squared_values = tensor_values[executed_op_types.index("Pow")]
     self.assertAllClose(v2_squared_values, [9.0])
 
+  @test_util.run_in_graph_and_eager_modes
+  def testNestedContextIsCapturedByGraphOpCreationHistory(self):
+    writer = dumping_callback.enable_dump_debug_info(
+        self.dump_root, tensor_debug_mode="NO_TENSOR")
+
+    @def_function.function
+    def iterative_doubling(x, times):
+      i = constant_op.constant(0, dtype=dtypes.int32)
+      while i < times:
+        x = x * 2.0 - 1.0
+        i += 1
+      return x
+
+    x = constant_op.constant(2.0, dtype=dtypes.float32)
+    times = constant_op.constant(4, dtype=dtypes.int32)
+    # 2 * 2 - 1 = 3; 3 * 2 - 1 = 5; 5 * 2 - 1 = 9; 9 * 2 - 1 = 17.
+    self.assertAllClose(self.evaluate(iterative_doubling(x, times)), 17.0)
+
+    writer.FlushNonExecutionFiles()
+    writer.FlushExecutionFiles()
+
+    stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
+    (_, _, op_name_to_op_type,
+     op_name_to_context_id) = self._readAndCheckGraphsFile(stack_frame_by_id)
+
+    less_op_names = [op_name for op_name in op_name_to_op_type
+                     if op_name_to_op_type[op_name] == "Less"]
+    less_context_ids = [op_name_to_context_id[op_name]
+                        for op_name in less_op_names]
+    mul_op_names = [op_name for op_name in op_name_to_op_type
+                    if op_name_to_op_type[op_name] == "Mul"]
+    mul_context_ids = [op_name_to_context_id[op_name]
+                       for op_name in mul_op_names]
+    sub_op_names = [op_name for op_name in op_name_to_op_type
+                    if op_name_to_op_type[op_name] == "Sub"]
+    sub_context_ids = [op_name_to_context_id[op_name]
+                       for op_name in sub_op_names]
+    self.assertLen(less_context_ids, 1)
+    self.assertLen(mul_context_ids, 1)
+    self.assertLen(sub_context_ids, 1)
+    self.assertTrue(less_context_ids[0])
+    self.assertTrue(mul_context_ids[0])
+    self.assertTrue(sub_context_ids[0])
+    # The Less op is from the while-loop cond context and hence should have
+    # a different innermost context ID from the mul and sub ops, which are both
+    # from the while-loop body context.
+    self.assertNotEqual(less_context_ids[0], mul_context_ids[0])
+    self.assertNotEqual(less_context_ids[0], sub_context_ids[0])
+    # The Mul and Sub ops are from the same innermost context.
+    self.assertEqual(mul_context_ids[0], sub_context_ids[0])
+
   @parameterized.named_parameters(
       ("NoTensor", "NO_TENSOR"),
       ("FullTensor", "FULL_TENSOR"),
@@ -736,7 +787,7 @@ class TracingCallbackTest(
 
     stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
     (context_ids, op_types,
-     op_name_to_op_type) = self._readAndCheckGraphsFile(stack_frame_by_id)
+     op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
     # Simply assert that graph are recorded and refrain from asserting on the
     # internal details of the Keras model.
     self.assertTrue(context_ids)
@@ -803,7 +854,7 @@ class TracingCallbackTest(
 
     stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
     (context_ids, op_types,
-     op_name_to_op_type) = self._readAndCheckGraphsFile(stack_frame_by_id)
+     op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
     # Simply assert that graph are recorded and refrain from asserting on the
     # internal details of the Keras model.
     self.assertTrue(context_ids)
@@ -876,7 +927,7 @@ class TracingCallbackTest(
 
     stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
     (context_ids, op_types,
-     op_name_to_op_type) = self._readAndCheckGraphsFile(stack_frame_by_id)
+     op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
     # Simply assert that graph are recorded and refrain from asserting on the
     # internal details of the Keras model.
     self.assertTrue(context_ids)
diff --git a/tensorflow/python/debug/lib/dumping_callback_test_lib.py b/tensorflow/python/debug/lib/dumping_callback_test_lib.py
index 2169ab9ce2b..e572c48d04c 100644
--- a/tensorflow/python/debug/lib/dumping_callback_test_lib.py
+++ b/tensorflow/python/debug/lib/dumping_callback_test_lib.py
@@ -116,35 +116,72 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase):
       op_types: Types of the ops that are created, as a `list` of `str`s with
         the same length as `context_ids`.
       op_name_to_op_type: A `dict` mapping op name to op type.
+      op_name_to_context_id: A `dict` mapping op name to the ID of the innermost
+        containing graph (context).
     """
     reader = debug_events_reader.DebugEventsReader(self.dump_root)
     graphs_iter = reader.graphs_iterator()
     prev_wall_time = 0
     op_types = []
     op_name_to_op_type = dict()
+    op_name_to_context_id = dict()  # Maps op name to ID of innermost context.
     context_ids = set()
     symbolic_tensor_ids = set()
+    # Maps context ID to ID of directly enclosing context (`None` for
+    # outermost contexts).
+    context_id_to_outer_id = dict()
+
     for debug_event in graphs_iter:
       self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
       prev_wall_time = debug_event.wall_time
-      graph_op_creation = debug_event.graph_op_creation
-      self.assertTrue(graph_op_creation.op_type)
-      op_types.append(graph_op_creation.op_type)
-      self.assertTrue(graph_op_creation.op_name)
-      op_name_to_op_type[graph_op_creation.op_name] = graph_op_creation.op_type
-      self.assertTrue(graph_op_creation.graph_id)
-      context_ids.add(graph_op_creation.graph_id)
-      self.assertTrue(graph_op_creation.code_location)
-      if graph_op_creation.num_outputs:
-        self.assertLen(graph_op_creation.output_tensor_ids,
-                       graph_op_creation.num_outputs)
-        # Check that all symblic tensor IDs are unique.
-        for tensor_id in graph_op_creation.output_tensor_ids:
-          self.assertNotIn(tensor_id, symbolic_tensor_ids)
-          symbolic_tensor_ids.add(tensor_id)
-      for stack_frame_id in graph_op_creation.code_location.stack_frame_ids:
-        self.assertIn(stack_frame_id, stack_frame_by_id)
-    return context_ids, op_types, op_name_to_op_type
+      # A DebugEvent in the .graphs file contains either of the two fields:
+      # - graph_op_creation for creation of a symbolic op in a graph context.
+      # - debugged_graph for information regarding the graph (context).
+      if debug_event.graph_op_creation.ByteSize():
+        graph_op_creation = debug_event.graph_op_creation
+        self.assertTrue(graph_op_creation.op_type)
+        op_types.append(graph_op_creation.op_type)
+        self.assertTrue(graph_op_creation.op_name)
+        op_name_to_op_type[
+            graph_op_creation.op_name] = graph_op_creation.op_type
+        op_name_to_context_id[
+            graph_op_creation.op_name] = graph_op_creation.graph_id
+        self.assertTrue(graph_op_creation.graph_id)
+        context_ids.add(graph_op_creation.graph_id)
+        self.assertTrue(graph_op_creation.code_location)
+        if graph_op_creation.num_outputs:
+          self.assertLen(graph_op_creation.output_tensor_ids,
+                         graph_op_creation.num_outputs)
+          # Check that all symblic tensor IDs are unique.
+          for tensor_id in graph_op_creation.output_tensor_ids:
+            self.assertNotIn(tensor_id, symbolic_tensor_ids)
+            symbolic_tensor_ids.add(tensor_id)
+        for stack_frame_id in graph_op_creation.code_location.stack_frame_ids:
+          self.assertIn(stack_frame_id, stack_frame_by_id)
+      else:
+        debugged_graph = debug_event.debugged_graph
+        if debugged_graph.outer_context_id:
+          inner_id = debugged_graph.graph_id
+          outer_id = debugged_graph.outer_context_id
+          if inner_id in context_id_to_outer_id:
+            # The outer context of a context must be always the same.
+            self.assertEqual(context_id_to_outer_id[inner_id], outer_id)
+          else:
+            context_id_to_outer_id[inner_id] = outer_id
+        else:
+          # This is an outermost context.
+          if debugged_graph.graph_id in context_id_to_outer_id:
+            self.assertIsNone(context_id_to_outer_id[debugged_graph.graph_id])
+          else:
+            context_id_to_outer_id[debugged_graph.graph_id] = None
+
+    # If any graph is created, the graph context hierarchy must be populated.
+    # In addition, the context of each graph op must be locatable within the
+    # graph context hierarchy.
+    for context_id in op_name_to_context_id.values():
+      self.assertIn(context_id, context_id_to_outer_id)
+
+    return context_ids, op_types, op_name_to_op_type, op_name_to_context_id
 
   def _readAndCheckExecutionFile(self, dump_root=None):
     """Read and verify the content of the .execution debug-event file.
diff --git a/tensorflow/python/debug/lib/source_utils.py b/tensorflow/python/debug/lib/source_utils.py
index 9c8f7733ef9..2a59ab97fc7 100644
--- a/tensorflow/python/debug/lib/source_utils.py
+++ b/tensorflow/python/debug/lib/source_utils.py
@@ -88,7 +88,7 @@ def guess_is_tensorflow_py_library(py_file_path):
 
 
 def load_source(source_file_path):
-  with open(source_file_path, "rU") as f:
+  with open(source_file_path, "r") as f:
     source_text = f.read()
   source_lines = source_text.split("\n")
   line_num_width = int(np.ceil(np.log10(len(source_lines)))) + 3

From 33877af57f9998369baee28ca5cde2d69b6f943d Mon Sep 17 00:00:00 2001
From: Gunhan Gulsoy 
Date: Wed, 27 Nov 2019 21:40:35 -0800
Subject: [PATCH 084/279] Remove dependence on core/lib/strings

PiperOrigin-RevId: 282881315
Change-Id: I3d6d18ec1dbf068f9d2d69c605e343994e9d986f
---
 tensorflow/core/platform/cloud/BUILD                       | 2 ++
 tensorflow/core/platform/cloud/gcs_file_system.cc          | 2 +-
 .../profile_utils/android_armv7a_cpu_utils_helper.cc       | 2 +-
 tensorflow/core/platform/s3/BUILD                          | 1 +
 tensorflow/core/platform/s3/aws_logging.cc                 | 7 ++++---
 5 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD
index d578b1a2388..8bd431467e1 100644
--- a/tensorflow/core/platform/cloud/BUILD
+++ b/tensorflow/core/platform/cloud/BUILD
@@ -99,6 +99,7 @@ cc_library(
         "//tensorflow/core:framework_headers_lib",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/platform:stringprintf",
         "@jsoncpp_git//:jsoncpp",
     ],
     alwayslink = 1,
@@ -131,6 +132,7 @@ cc_library(
         "//tensorflow/core:framework_headers_lib",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/platform:stringprintf",
         "@jsoncpp_git//:jsoncpp",
     ],
     alwayslink = 1,
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index 55bbec7cb88..c59b71158da 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -32,7 +32,6 @@ limitations under the License.
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/lib/strings/numbers.h"
 #include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/cloud/curl_http_request.h"
 #include "tensorflow/core/platform/cloud/file_block_cache.h"
 #include "tensorflow/core/platform/cloud/google_auth_provider.h"
@@ -42,6 +41,7 @@ limitations under the License.
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/stringprintf.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 
 #ifdef _WIN32
diff --git a/tensorflow/core/platform/profile_utils/android_armv7a_cpu_utils_helper.cc b/tensorflow/core/platform/profile_utils/android_armv7a_cpu_utils_helper.cc
index 12dc9c58b38..0534443d17c 100644
--- a/tensorflow/core/platform/profile_utils/android_armv7a_cpu_utils_helper.cc
+++ b/tensorflow/core/platform/profile_utils/android_armv7a_cpu_utils_helper.cc
@@ -28,8 +28,8 @@ limitations under the License.
 #include 
 #include 
 
-#include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stringprintf.h"
 
 namespace tensorflow {
 namespace profile_utils {
diff --git a/tensorflow/core/platform/s3/BUILD b/tensorflow/core/platform/s3/BUILD
index d518bfb71a2..d0c68baa860 100644
--- a/tensorflow/core/platform/s3/BUILD
+++ b/tensorflow/core/platform/s3/BUILD
@@ -65,6 +65,7 @@ cc_library(
     deps = [
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/platform:stringprintf",
         "@aws",
     ],
     alwayslink = 1,
diff --git a/tensorflow/core/platform/s3/aws_logging.cc b/tensorflow/core/platform/s3/aws_logging.cc
index dac56908893..1d549a2a61e 100644
--- a/tensorflow/core/platform/s3/aws_logging.cc
+++ b/tensorflow/core/platform/s3/aws_logging.cc
@@ -13,9 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 #include "tensorflow/core/platform/s3/aws_logging.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/mutex.h"
 
 #include 
 #include 
@@ -23,6 +20,10 @@ limitations under the License.
 
 #include 
 
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/stringprintf.h"
+
 namespace tensorflow {
 
 AWSLogSystem::AWSLogSystem(Aws::Utils::Logging::LogLevel log_level)

From 64ba82b7d327b2f2be8299d879c6403be5453e37 Mon Sep 17 00:00:00 2001
From: Derek Murray 
Date: Wed, 27 Nov 2019 21:47:07 -0800
Subject: [PATCH 085/279] Rolling forward "Add PrivateIntraProcessRendezvous."
 with a fix.

For now, when using ExecutorBarrier with multiple devices, we continue to use a RefCountedIntraProcessRendezvous.

PiperOrigin-RevId: 282882166
Change-Id: I775818ba60db43f34745fb221e63e4d6ca065121
---
 .../core/common_runtime/direct_session.cc     | 13 ++-
 .../core/common_runtime/direct_session.h      |  4 +-
 tensorflow/core/common_runtime/function.cc    | 12 +--
 .../core/common_runtime/function_test.cc      |  5 +-
 .../process_function_library_runtime_test.cc  | 39 ++++----
 .../core/common_runtime/rendezvous_mgr.cc     | 97 +++++++++++++------
 .../core/common_runtime/rendezvous_mgr.h      | 67 +++++++------
 7 files changed, 144 insertions(+), 93 deletions(-)

diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 133a6c31a93..c836cb23898 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -521,7 +521,6 @@ Status DirectSession::RunInternal(
                             executor_step_count, &debugger_state));
   }
 
-  run_state.rendez.reset(new IntraProcessRendezvous(device_mgr_.get()));
 #ifndef __ANDROID__
   // Set up for collectives if ExecutorsAndKeys declares a key.
   if (executors_and_keys->collective_graph_key !=
@@ -616,7 +615,6 @@ Status DirectSession::RunInternal(
   Executor::Args args;
   args.step_id = step_id;
   args.call_frame = call_frame;
-  args.rendezvous = run_state.rendez.get();
   args.collective_executor =
       (run_state.collective_executor ? run_state.collective_executor->get()
                                      : nullptr);
@@ -688,14 +686,21 @@ Status DirectSession::RunInternal(
       };
 
   if (can_execute_synchronously) {
+    PrivateIntraProcessRendezvous rendezvous(device_mgr_.get());
+    args.rendezvous = &rendezvous;
+
     const auto& item = executors_and_keys->items[0];
     set_threadpool_args_for_item(item, &args);
     run_status = item.executor->Run(args);
   } else {
+    core::RefCountPtr rendezvous(
+        new RefCountedIntraProcessRendezvous(device_mgr_.get()));
+    args.rendezvous = rendezvous.get();
+
     // `barrier` will delete itself after the final executor finishes.
     Notification executors_done;
     ExecutorBarrier* barrier =
-        new ExecutorBarrier(num_executors, run_state.rendez.get(),
+        new ExecutorBarrier(num_executors, rendezvous.get(),
                             [&run_state, &executors_done](const Status& ret) {
                               {
                                 mutex_lock l(run_state.mu);
@@ -1139,7 +1144,7 @@ Status DirectSession::SendPRunInputs(const NamedTensorList& inputs,
 
 Status DirectSession::RecvPRunOutputs(
     const std::vector& output_names,
-    const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
+    const ExecutorsAndKeys* executors_and_keys, PartialRunState* run_state,
     std::vector* outputs) {
   Status s;
   if (!output_names.empty()) {
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index a272633b4e2..7bbb198ef44 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -191,7 +191,6 @@ class DirectSession : public Session {
   struct RunState {
     mutex mu;
     Status status GUARDED_BY(mu);
-    core::RefCountPtr rendez = nullptr;
     std::unique_ptr collective_executor;
     std::unique_ptr collector;
     TensorStore tensor_store;
@@ -208,6 +207,7 @@ class DirectSession : public Session {
     Notification executors_done;
     std::unordered_map pending_inputs;   // true if fed
     std::unordered_map pending_outputs;  // true if fetched
+    core::RefCountPtr rendez = nullptr;
 
     PartialRunState(const std::vector& pending_input_names,
                     const std::vector& pending_output_names,
@@ -282,7 +282,7 @@ class DirectSession : public Session {
   // tensors are computed.
   ::tensorflow::Status RecvPRunOutputs(
       const std::vector& output_names,
-      const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
+      const ExecutorsAndKeys* executors_and_keys, PartialRunState* run_state,
       std::vector* outputs);
 
   // Check if the specified fetches can be computed from the feeds
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index aa3be38fd29..501002e1f7f 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -1116,11 +1116,11 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
   }
   Options run_opts = opts;
   if (opts.create_rendezvous) {
-    Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_);
+    auto* rendezvous = new PrivateIntraProcessRendezvous(device_mgr_);
     run_opts.rendezvous = rendezvous;
     run_opts.create_rendezvous = false;
-    done = [done = std::move(done), rendezvous](const Status& status) {
-      rendezvous->Unref();
+    done = [done = std::move(done), rendezvous](const Status& status) mutable {
+      delete rendezvous;
       done(status);
     };
   }
@@ -1187,11 +1187,11 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
 
   Options run_opts = opts;
   if (opts.create_rendezvous) {
-    Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_);
+    auto* rendezvous = new PrivateIntraProcessRendezvous(device_mgr_);
     run_opts.rendezvous = rendezvous;
     run_opts.create_rendezvous = false;
-    done = [done = std::move(done), rendezvous](const Status& status) {
-      rendezvous->Unref();
+    done = [done = std::move(done), rendezvous](const Status& status) mutable {
+      delete rendezvous;
       done(status);
     };
   }
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 7c76c469d1e..89e4daa50b3 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -1854,8 +1854,8 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
 
   Tensor y;
   FunctionLibraryRuntime::Options opts;
-  Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_.get());
-  opts.rendezvous = rendezvous;
+  PrivateIntraProcessRendezvous rendezvous(device_mgr_.get());
+  opts.rendezvous = &rendezvous;
   opts.source_device = "/device:CPU:1";
   // Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1.
   TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y}, true));
@@ -1870,7 +1870,6 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
       y,
       test::AsTensor({"/job:localhost/replica:0/task:0/device:CPU:1"},
                               TensorShape({})));
-  rendezvous->Unref();
 }
 
 namespace {
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
index 1a5ed3caa11..55bc408f9c5 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
@@ -110,12 +110,6 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
     }
   }
 
-  ~ProcessFunctionLibraryRuntimeTest() override {
-    if (rendezvous_ != nullptr) {
-      rendezvous_->Unref();
-    }
-  }
-
   void Init(const std::vector& flib,
             const SessionMetadata* session_metadata = nullptr) {
     FunctionDefLibrary proto;
@@ -127,7 +121,8 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
         device_mgr_.get(), Env::Default(), /*config=*/nullptr,
         TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, nullptr, cluster_flr_.get(),
         nullptr, session_metadata));
-    rendezvous_ = new IntraProcessRendezvous(device_mgr_.get());
+    rendezvous_ =
+        absl::make_unique(device_mgr_.get());
   }
 
   Status Instantiate(
@@ -263,7 +258,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
           test::function::FunctionTestSchedClosure(fn);
         };
 
-    opts.rendezvous = rendezvous_;
+    opts.rendezvous = rendezvous_.get();
     opts.runner = &runner;
     Status status;
     Notification done;
@@ -292,7 +287,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
   std::unique_ptr lib_def_;
   std::unique_ptr cluster_flr_;
   std::unique_ptr proc_flr_;
-  IntraProcessRendezvous* rendezvous_ = nullptr;
+  std::unique_ptr rendezvous_ = nullptr;
 };
 
 TEST_F(ProcessFunctionLibraryRuntimeTest, GetFLRNull) {
@@ -344,7 +339,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCall) {
   Init({test::function::XTimesTwo()});
   FunctionLibraryRuntime::Options opts;
   opts.source_device = "/job:a/replica:0/task:0/cpu:0";
-  opts.rendezvous = rendezvous_;
+  opts.rendezvous = rendezvous_.get();
   opts.remote_execution = true;
   FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
   instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
@@ -359,7 +354,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCallFindDevice) {
   Init({test::function::FindDevice()});
   FunctionLibraryRuntime::Options opts;
   opts.source_device = "/job:a/replica:0/task:0/cpu:0";
-  opts.rendezvous = rendezvous_;
+  opts.rendezvous = rendezvous_.get();
   opts.remote_execution = true;
   FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
   instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
@@ -375,7 +370,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceXTimes) {
   auto x = test::AsTensor({1, 2, 3, 4});
   FunctionLibraryRuntime::Options opts;
   opts.source_device = "/job:a/replica:0/task:0/cpu:0";
-  opts.rendezvous = rendezvous_;
+  opts.rendezvous = rendezvous_.get();
   opts.remote_execution = true;
   FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
   instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
@@ -392,7 +387,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceFindDevice) {
   Init({test::function::FindDevice()});
   FunctionLibraryRuntime::Options opts;
   opts.source_device = "/job:a/replica:0/task:0/cpu:0";
-  opts.rendezvous = rendezvous_;
+  opts.rendezvous = rendezvous_.get();
   opts.remote_execution = true;
   FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
   instantiate_opts.target = "/job:a/replica:0/task:0/cpu:1";
@@ -411,7 +406,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) {
   Init({test::function::FindDevice()});
   FunctionLibraryRuntime::Options opts;
   opts.source_device = "/job:a/replica:0/task:0/cpu:0";
-  opts.rendezvous = rendezvous_;
+  opts.rendezvous = rendezvous_.get();
   opts.remote_execution = true;
   Tensor y;
   FunctionLibraryRuntime::InstantiateOptions instantiate_opts_0;
@@ -432,7 +427,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, ClusterFLRSerialTest) {
   Init({test::function::FindDevice()});
   FunctionLibraryRuntime::Options opts;
   opts.source_device = "/job:a/replica:0/task:0/cpu:0";
-  opts.rendezvous = rendezvous_;
+  opts.rendezvous = rendezvous_.get();
   opts.remote_execution = true;
   FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
   instantiate_opts.target = "/job:b/replica:0/task:0/device:CPU:0";
@@ -462,7 +457,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, ClusterFLRParallelTest) {
   Init({test::function::FindDevice()});
   FunctionLibraryRuntime::Options opts;
   opts.source_device = "/job:a/replica:0/task:0/cpu:0";
-  opts.rendezvous = rendezvous_;
+  opts.rendezvous = rendezvous_.get();
   opts.remote_execution = true;
   FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
   instantiate_opts.target = "/job:b/replica:0/task:0/device:CPU:0";
@@ -509,7 +504,7 @@ void TestTwoDeviceMult(
     const string& error = "") {
   fixture->Init({test::function::TwoDeviceMult()});
   FunctionLibraryRuntime::Options opts;
-  opts.rendezvous = fixture->rendezvous_;
+  opts.rendezvous = fixture->rendezvous_.get();
   auto x = test::AsTensor({1, 2, 3});
   Tensor y_cpu;
   Tensor y_gpu;
@@ -542,7 +537,7 @@ void TestTwoDeviceInputOutput(
   fixture->Init({test::function::TwoDeviceInputOutput()});
 
   FunctionLibraryRuntime::Options opts;
-  opts.rendezvous = fixture->rendezvous_;
+  opts.rendezvous = fixture->rendezvous_.get();
   Tensor x1 = test::AsTensor({1, 2});
   if (absl::StrContains(inst_opts.input_devices[0], "GPU")) {
     x1 = fixture->CPUToGPU(x1);
@@ -743,7 +738,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ResourceOutput_GPU) {
 
   // Run the function taking a resource and outputing it
   FunctionLibraryRuntime::Options opts;
-  opts.rendezvous = rendezvous_;
+  opts.rendezvous = rendezvous_.get();
   Tensor x1 = CPUToGPU(test::AsTensor({1, 2}));
   Tensor x2 = GetResourceHandle("my_gpu_var", mgr->default_container(),
                                 "/job:a/replica:0/task:0/device:GPU:0");
@@ -985,7 +980,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataAbsent) {
   Init({SessionMetadataReaderOpFn()}, /*session_metadata=*/nullptr);
   FunctionLibraryRuntime::Options opts;
   opts.source_device = "/job:a/replica:0/task:0/cpu:0";
-  opts.rendezvous = rendezvous_;
+  opts.rendezvous = rendezvous_.get();
   opts.remote_execution = true;
   FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
   instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
@@ -1001,7 +996,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresent) {
   Init({SessionMetadataReaderOpFn()}, &session_metadata);
   FunctionLibraryRuntime::Options opts;
   opts.source_device = "/job:a/replica:0/task:0/cpu:0";
-  opts.rendezvous = rendezvous_;
+  opts.rendezvous = rendezvous_.get();
   opts.remote_execution = true;
   FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
   instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
@@ -1027,7 +1022,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresentAfterCloning) {
   TF_ASSERT_OK(flr->Clone(&cloned_lib_def, &cloned_proc_flr, &cloned_flr));
   FunctionLibraryRuntime::Options opts;
   opts.source_device = "/job:a/replica:0/task:0/cpu:0";
-  opts.rendezvous = rendezvous_;
+  opts.rendezvous = rendezvous_.get();
   opts.remote_execution = true;
   FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
   instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.cc b/tensorflow/core/common_runtime/rendezvous_mgr.cc
index 0d5e79667db..6ed7df2cc1e 100644
--- a/tensorflow/core/common_runtime/rendezvous_mgr.cc
+++ b/tensorflow/core/common_runtime/rendezvous_mgr.cc
@@ -32,23 +32,12 @@ limitations under the License.
 
 namespace tensorflow {
 
-IntraProcessRendezvous::IntraProcessRendezvous(const DeviceMgr* device_mgr)
-    : device_mgr_(device_mgr) {}
-
-IntraProcessRendezvous::~IntraProcessRendezvous() {}
-
-Status IntraProcessRendezvous::Send(const ParsedKey& key,
-                                    const Rendezvous::Args& args,
-                                    const Tensor& val, const bool is_dead) {
-  VLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey();
-  // Buffers "val" and "device_context" in local_.
-  return local_.Send(key, args, val, is_dead);
-}
-
-void IntraProcessRendezvous::SameWorkerRecvDone(
-    const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& send_args,
-    const Rendezvous::Args& recv_args, const Tensor& in, Tensor* out,
-    StatusCallback done) {
+namespace {
+void SameWorkerRecvDone(const DeviceMgr* device_mgr,
+                        const Rendezvous::ParsedKey& parsed,
+                        const Rendezvous::Args& send_args,
+                        const Rendezvous::Args& recv_args, const Tensor& in,
+                        Tensor* out, StatusCallback done) {
   // Do a quick copy (sharing the underlying buffer) if both tensors
   // are on host memory.
   const bool src_host =
@@ -73,13 +62,13 @@ void IntraProcessRendezvous::SameWorkerRecvDone(
   }
 
   Device* src_device;
-  Status s = device_mgr_->LookupDevice(parsed.src_device, &src_device);
+  Status s = device_mgr->LookupDevice(parsed.src_device, &src_device);
   if (!s.ok()) {
     done(s);
     return;
   }
   Device* dst_device;
-  s = device_mgr_->LookupDevice(parsed.dst_device, &dst_device);
+  s = device_mgr->LookupDevice(parsed.dst_device, &dst_device);
   if (!s.ok()) {
     done(s);
     return;
@@ -116,16 +105,18 @@ void IntraProcessRendezvous::SameWorkerRecvDone(
       out, 0 /*dev_to_dev_stream_index*/, std::move(done), sync_dst_compute);
 }
 
-void IntraProcessRendezvous::RecvAsync(const ParsedKey& key,
-                                       const Rendezvous::Args& args,
-                                       DoneCallback done) {
-  VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << key.FullKey();
+void IntraProcessRecvAsyncImpl(const DeviceMgr* device_mgr,
+                               LocalRendezvous* local,
+                               const RendezvousInterface::ParsedKey& parsed,
+                               const Rendezvous::Args& recv_args,
+                               RendezvousInterface::DoneCallback done) {
+  VLOG(1) << "IntraProcessRendezvous Recv " << local << " " << parsed.FullKey();
 
   MEMDEBUG_CACHE_OP("RecvAsync");
   // Recv the tensor from local_.
-  local_.RecvAsync(
-      key, args,
-      [this, key, done = std::move(done)](
+  local->RecvAsync(
+      parsed, recv_args,
+      [device_mgr, parsed, done = std::move(done)](
           const Status& status, const Rendezvous::Args& send_args,
           const Rendezvous::Args& recv_args, const Tensor& in,
           bool is_dead) mutable {
@@ -141,7 +132,7 @@ void IntraProcessRendezvous::RecvAsync(const ParsedKey& key,
         };
 
         if (status.ok() && in.IsInitialized()) {
-          SameWorkerRecvDone(key, send_args, recv_args, in, out,
+          SameWorkerRecvDone(device_mgr, parsed, send_args, recv_args, in, out,
                              std::move(final_callback));
         } else {
           final_callback(status);
@@ -149,8 +140,56 @@ void IntraProcessRendezvous::RecvAsync(const ParsedKey& key,
       });
 }
 
-void IntraProcessRendezvous::StartAbort(const Status& s) {
-  CHECK(!s.ok());
+}  // namespace
+
+RefCountedIntraProcessRendezvous::RefCountedIntraProcessRendezvous(
+    const DeviceMgr* device_mgr)
+    : device_mgr_(device_mgr) {}
+
+RefCountedIntraProcessRendezvous::~RefCountedIntraProcessRendezvous() {}
+
+Status RefCountedIntraProcessRendezvous::Send(const ParsedKey& key,
+                                              const Rendezvous::Args& args,
+                                              const Tensor& val,
+                                              const bool is_dead) {
+  VLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey();
+  return local_.Send(key, args, val, is_dead);
+}
+
+void RefCountedIntraProcessRendezvous::RecvAsync(const ParsedKey& key,
+                                                 const Rendezvous::Args& args,
+                                                 DoneCallback done) {
+  VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << key.FullKey();
+  IntraProcessRecvAsyncImpl(device_mgr_, &local_, key, args, std::move(done));
+}
+
+void RefCountedIntraProcessRendezvous::StartAbort(const Status& s) {
+  local_.StartAbort(s);
+}
+
+PrivateIntraProcessRendezvous::PrivateIntraProcessRendezvous(
+    const DeviceMgr* device_mgr)
+    : device_mgr_(device_mgr) {}
+
+PrivateIntraProcessRendezvous::~PrivateIntraProcessRendezvous() {}
+
+Status PrivateIntraProcessRendezvous::Send(const ParsedKey& key,
+                                           const Rendezvous::Args& args,
+                                           const Tensor& val,
+                                           const bool is_dead) {
+  DVLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey();
+  return local_.Send(key, args, val, is_dead);
+}
+
+void PrivateIntraProcessRendezvous::RecvAsync(const ParsedKey& key,
+                                              const Rendezvous::Args& args,
+                                              DoneCallback done) {
+  DVLOG(1) << "StackAllocatedIntraProcessRendezvous Recv " << this << " "
+           << key.FullKey();
+  IntraProcessRecvAsyncImpl(device_mgr_, &local_, key, args, std::move(done));
+}
+
+void PrivateIntraProcessRendezvous::StartAbort(const Status& s) {
   local_.StartAbort(s);
 }
 
diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.h b/tensorflow/core/common_runtime/rendezvous_mgr.h
index a9d3de122f0..eea5fbe388c 100644
--- a/tensorflow/core/common_runtime/rendezvous_mgr.h
+++ b/tensorflow/core/common_runtime/rendezvous_mgr.h
@@ -30,48 +30,61 @@ limitations under the License.
 
 namespace tensorflow {
 
-// IntraProcessRendezvous is a Rendezvous which expects all producers
-// and consumers to be devices immediately accessible within the
-// process. That is, it will never be necessary to perform an RPC to
+// The IntraProcessRendezvous classes are implementations of a Rendezvous that
+// expects all producers and consumers to be devices immediately accessible
+// within the process. That is, it will never be necessary to perform an RPC to
 // communicate with either.
 //
-// Buffering of Tensor values is delegated to a `LocalRendezvous`. This class
-// just adds functionality to coordinate multiple process-local devices.
-class IntraProcessRendezvous : public Rendezvous {
- public:
-  explicit IntraProcessRendezvous(const DeviceMgr* device_mgr);
+// Buffering of Tensor values is delegated to a `LocalRendezvous`. An
+// IntraProcessRendezvous. just adds functionality to coordinate multiple
+// process-local devices.
 
-  // Forwards to local_, where the Tensor "val" will be buffered and
-  // any waiting callback stored.
+// Reference-counted implementation that may be shared between multiple threads.
+class RefCountedIntraProcessRendezvous : public Rendezvous {
+ public:
+  explicit RefCountedIntraProcessRendezvous(const DeviceMgr* device_mgr);
+
+  // Implementation of RendezvousInterface methods.
   Status Send(const ParsedKey& key, const Rendezvous::Args& args,
               const Tensor& val, const bool is_dead) override;
-
-  // This method is called only by the RecvOp.  It tests to see
-  // whether the value will be produced by a local or remote device
-  // and handles accordingly.  In the local case it forwards to
-  // local_, in the remote case it initiates an RPC request.
   void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args,
                  DoneCallback done) override;
-
   void StartAbort(const Status& status) override;
 
  private:
   const DeviceMgr* device_mgr_;
   LocalRendezvous local_;
 
-  ~IntraProcessRendezvous() override;
+  ~RefCountedIntraProcessRendezvous() override;
 
-  // Callback handling the case when a rendezvous has been
-  // accomplished in local_ and the consumer is local to this process.
-  // Tensor "in" will be copied into "out". The key "parsed" encodes
-  // the src and dst devices.
-  typedef std::function StatusCallback;
-  void SameWorkerRecvDone(const Rendezvous::ParsedKey& parsed,
-                          const Rendezvous::Args& send_args,
-                          const Rendezvous::Args& recv_args, const Tensor& in,
-                          Tensor* out, StatusCallback done);
+  TF_DISALLOW_COPY_AND_ASSIGN(RefCountedIntraProcessRendezvous);
+};
 
-  TF_DISALLOW_COPY_AND_ASSIGN(IntraProcessRendezvous);
+// RefCountedIntraProcessRendezvous is aliased to IntraProcessRendezvous for
+// backwards compatibility with existing users.
+using IntraProcessRendezvous = RefCountedIntraProcessRendezvous;
+
+// Non-reference-counted implementation that may be stack-allocated for
+// performance.
+//
+// Prefer to use PrivateIntraProcessRendezvous in new code.
+class PrivateIntraProcessRendezvous : public RendezvousInterface {
+ public:
+  explicit PrivateIntraProcessRendezvous(const DeviceMgr* device_mgr);
+  ~PrivateIntraProcessRendezvous() override;
+
+  // Implementation of RendezvousInterface methods.
+  Status Send(const ParsedKey& key, const Rendezvous::Args& args,
+              const Tensor& val, const bool is_dead) override;
+  void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args,
+                 DoneCallback done) override;
+  void StartAbort(const Status& status) override;
+
+ private:
+  const DeviceMgr* device_mgr_;
+  LocalRendezvous local_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(PrivateIntraProcessRendezvous);
 };
 
 }  // end namespace tensorflow

From 71aa7dda135171d5ee389d0a77f41af3deb0a524 Mon Sep 17 00:00:00 2001
From: Gunhan Gulsoy 
Date: Wed, 27 Nov 2019 22:12:26 -0800
Subject: [PATCH 086/279] Remove dependence on redirection header
 core/lib/strings/str_util

PiperOrigin-RevId: 282884423
Change-Id: I4c457c6b0cec249ed6e0c9858adb77446b322f34
---
 tensorflow/core/platform/cloud/BUILD                |  9 +++++++++
 .../platform/cloud/compute_engine_zone_provider.cc  |  3 ++-
 tensorflow/core/platform/cloud/curl_http_request.cc |  6 +++---
 .../core/platform/cloud/gcs_dns_cache_test.cc       |  3 ++-
 tensorflow/core/platform/cloud/gcs_file_system.cc   |  2 +-
 .../core/platform/cloud/gcs_file_system_test.cc     |  4 +++-
 tensorflow/core/platform/cloud/gcs_throttle_test.cc |  3 ++-
 .../platform/cloud/retrying_file_system_test.cc     |  4 +++-
 .../core/platform/cloud/retrying_utils_test.cc      |  4 +++-
 tensorflow/core/platform/default/test_benchmark.cc  |  6 +++---
 tensorflow/core/platform/env_test.cc                |  2 +-
 tensorflow/core/platform/file_system_test.cc        |  2 +-
 tensorflow/core/platform/hadoop/BUILD               |  1 +
 .../core/platform/hadoop/hadoop_file_system_test.cc |  2 +-
 tensorflow/core/platform/platform_strings_test.cc   |  5 +++--
 tensorflow/core/platform/s3/BUILD                   |  1 +
 tensorflow/core/platform/s3/s3_file_system.cc       | 13 +++++++------
 17 files changed, 46 insertions(+), 24 deletions(-)

diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD
index 8bd431467e1..e2d1f7b4e37 100644
--- a/tensorflow/core/platform/cloud/BUILD
+++ b/tensorflow/core/platform/cloud/BUILD
@@ -99,6 +99,7 @@ cc_library(
         "//tensorflow/core:framework_headers_lib",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/platform:str_util",
         "//tensorflow/core/platform:stringprintf",
         "@jsoncpp_git//:jsoncpp",
     ],
@@ -132,6 +133,7 @@ cc_library(
         "//tensorflow/core:framework_headers_lib",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/platform:str_util",
         "//tensorflow/core/platform:stringprintf",
         "@jsoncpp_git//:jsoncpp",
     ],
@@ -158,6 +160,7 @@ cc_library(
         ":http_request",
         "//tensorflow/core:framework_headers_lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/platform:str_util",
         "//tensorflow/core/platform:stringpiece",
         "@curl",
     ],
@@ -231,6 +234,7 @@ cc_library(
         ":compute_engine_metadata_client",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/platform:str_util",
     ],
 )
 
@@ -344,6 +348,7 @@ tf_cc_test(
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "//tensorflow/core/platform:str_util",
     ],
 )
 
@@ -357,6 +362,7 @@ tf_cc_test(
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "//tensorflow/core/platform:str_util",
     ],
 )
 
@@ -370,6 +376,7 @@ tf_cc_test(
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "//tensorflow/core/platform:str_util",
     ],
 )
 
@@ -461,6 +468,7 @@ tf_cc_test(
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "//tensorflow/core/platform:str_util",
     ],
 )
 
@@ -485,5 +493,6 @@ tf_cc_test(
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "//tensorflow/core/platform:str_util",
     ],
 )
diff --git a/tensorflow/core/platform/cloud/compute_engine_zone_provider.cc b/tensorflow/core/platform/cloud/compute_engine_zone_provider.cc
index e147d883710..8008c9cc9ec 100644
--- a/tensorflow/core/platform/cloud/compute_engine_zone_provider.cc
+++ b/tensorflow/core/platform/cloud/compute_engine_zone_provider.cc
@@ -16,7 +16,8 @@ limitations under the License.
 #include "tensorflow/core/platform/cloud/compute_engine_zone_provider.h"
 
 #include 
-#include "tensorflow/core/lib/strings/str_util.h"
+
+#include "tensorflow/core/platform/str_util.h"
 namespace tensorflow {
 
 namespace {
diff --git a/tensorflow/core/platform/cloud/curl_http_request.cc b/tensorflow/core/platform/cloud/curl_http_request.cc
index c64e215ea99..ea007affac7 100644
--- a/tensorflow/core/platform/cloud/curl_http_request.cc
+++ b/tensorflow/core/platform/cloud/curl_http_request.cc
@@ -13,15 +13,15 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include 
-
 #include "tensorflow/core/platform/cloud/curl_http_request.h"
 
+#include 
+
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/strings/scanner.h"
-#include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/str_util.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/public/version.h"
 
diff --git a/tensorflow/core/platform/cloud/gcs_dns_cache_test.cc b/tensorflow/core/platform/cloud/gcs_dns_cache_test.cc
index 77850906c6c..09644767152 100644
--- a/tensorflow/core/platform/cloud/gcs_dns_cache_test.cc
+++ b/tensorflow/core/platform/cloud/gcs_dns_cache_test.cc
@@ -14,7 +14,8 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/platform/cloud/gcs_dns_cache.h"
-#include "tensorflow/core/lib/strings/str_util.h"
+
+#include "tensorflow/core/platform/str_util.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index c59b71158da..c461ab61313 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -31,7 +31,6 @@ limitations under the License.
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/platform/cloud/curl_http_request.h"
 #include "tensorflow/core/platform/cloud/file_block_cache.h"
 #include "tensorflow/core/platform/cloud/google_auth_provider.h"
@@ -41,6 +40,7 @@ limitations under the License.
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/str_util.h"
 #include "tensorflow/core/platform/stringprintf.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 
diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
index 566ad45a43c..387c6391a15 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
@@ -14,11 +14,13 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/platform/cloud/gcs_file_system.h"
+
 #include 
+
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/platform/cloud/http_request_fake.h"
+#include "tensorflow/core/platform/str_util.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/platform/cloud/gcs_throttle_test.cc b/tensorflow/core/platform/cloud/gcs_throttle_test.cc
index e8eebc5fbc3..404e922502d 100644
--- a/tensorflow/core/platform/cloud/gcs_throttle_test.cc
+++ b/tensorflow/core/platform/cloud/gcs_throttle_test.cc
@@ -14,8 +14,9 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/platform/cloud/gcs_throttle.h"
+
 #include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/str_util.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/platform/cloud/retrying_file_system_test.cc b/tensorflow/core/platform/cloud/retrying_file_system_test.cc
index 2b26f27f82c..1df371a6080 100644
--- a/tensorflow/core/platform/cloud/retrying_file_system_test.cc
+++ b/tensorflow/core/platform/cloud/retrying_file_system_test.cc
@@ -14,9 +14,11 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/platform/cloud/retrying_file_system.h"
+
 #include 
+
 #include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/str_util.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/platform/cloud/retrying_utils_test.cc b/tensorflow/core/platform/cloud/retrying_utils_test.cc
index 771bb44285e..7a2dbacacc8 100644
--- a/tensorflow/core/platform/cloud/retrying_utils_test.cc
+++ b/tensorflow/core/platform/cloud/retrying_utils_test.cc
@@ -14,10 +14,12 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/platform/cloud/retrying_utils.h"
+
 #include 
+
 #include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/str_util.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/platform/default/test_benchmark.cc b/tensorflow/core/platform/default/test_benchmark.cc
index dedab42bd73..77747fc5b79 100644
--- a/tensorflow/core/platform/default/test_benchmark.cc
+++ b/tensorflow/core/platform/default/test_benchmark.cc
@@ -15,14 +15,14 @@ limitations under the License.
 
 #include "tensorflow/core/platform/test_benchmark.h"
 
+#include 
 #include 
 #include 
-
-#include 
 #include 
-#include "tensorflow/core/lib/strings/str_util.h"
+
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/str_util.h"
 #include "tensorflow/core/util/reporter.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc
index 8298df9a817..06e09b1911d 100644
--- a/tensorflow/core/platform/env_test.cc
+++ b/tensorflow/core/platform/env_test.cc
@@ -21,11 +21,11 @@ limitations under the License.
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/cord.h"
 #include "tensorflow/core/platform/null_file_system.h"
 #include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/str_util.h"
 #include "tensorflow/core/platform/stringpiece.h"
 #include "tensorflow/core/platform/test.h"
 
diff --git a/tensorflow/core/platform/file_system_test.cc b/tensorflow/core/platform/file_system_test.cc
index a931634a3c8..78358bd3458 100644
--- a/tensorflow/core/platform/file_system_test.cc
+++ b/tensorflow/core/platform/file_system_test.cc
@@ -19,9 +19,9 @@ limitations under the License.
 
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/null_file_system.h"
+#include "tensorflow/core/platform/str_util.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/platform/hadoop/BUILD b/tensorflow/core/platform/hadoop/BUILD
index fc6ae4dc6b7..b68eb0bcd4f 100644
--- a/tensorflow/core/platform/hadoop/BUILD
+++ b/tensorflow/core/platform/hadoop/BUILD
@@ -58,5 +58,6 @@ tf_cc_test(
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "//tensorflow/core/platform:str_util",
     ],
 )
diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system_test.cc b/tensorflow/core/platform/hadoop/hadoop_file_system_test.cc
index 1242e2547fc..0c21e9662ee 100644
--- a/tensorflow/core/platform/hadoop/hadoop_file_system_test.cc
+++ b/tensorflow/core/platform/hadoop/hadoop_file_system_test.cc
@@ -17,8 +17,8 @@ limitations under the License.
 
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/platform/file_system.h"
+#include "tensorflow/core/platform/str_util.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/platform/platform_strings_test.cc b/tensorflow/core/platform/platform_strings_test.cc
index a4eb845c25e..3824ff550f3 100644
--- a/tensorflow/core/platform/platform_strings_test.cc
+++ b/tensorflow/core/platform/platform_strings_test.cc
@@ -15,6 +15,8 @@ limitations under the License.
 
 // Test for the platform_strings.h header file.
 
+#include "tensorflow/core/platform/platform_strings.h"
+
 #include 
 #include 
 #include 
@@ -24,11 +26,10 @@ limitations under the License.
 #include 
 
 #include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/init_main.h"
 #include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/platform_strings.h"
+#include "tensorflow/core/platform/str_util.h"
 
 // Embed the platform strings in this binary.
 TF_PLATFORM_STRINGS()
diff --git a/tensorflow/core/platform/s3/BUILD b/tensorflow/core/platform/s3/BUILD
index d0c68baa860..d22a759c440 100644
--- a/tensorflow/core/platform/s3/BUILD
+++ b/tensorflow/core/platform/s3/BUILD
@@ -84,6 +84,7 @@ cc_library(
         ":aws_logging",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/platform:str_util",
         "@aws",
     ],
     alwayslink = 1,
diff --git a/tensorflow/core/platform/s3/s3_file_system.cc b/tensorflow/core/platform/s3/s3_file_system.cc
index d32158f70bd..8c821faa651 100644
--- a/tensorflow/core/platform/s3/s3_file_system.cc
+++ b/tensorflow/core/platform/s3/s3_file_system.cc
@@ -13,12 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 #include "tensorflow/core/platform/s3/s3_file_system.h"
-#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/platform/file_system_helper.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/s3/aws_crypto.h"
-#include "tensorflow/core/platform/s3/aws_logging.h"
 
 #include 
 #include 
@@ -38,6 +32,13 @@ limitations under the License.
 
 #include 
 
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/file_system_helper.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/s3/aws_crypto.h"
+#include "tensorflow/core/platform/s3/aws_logging.h"
+#include "tensorflow/core/platform/str_util.h"
+
 namespace tensorflow {
 
 namespace {

From 4ffdcd9755d77dfaa6cd74ac0cc53ddc0cb5bd4b Mon Sep 17 00:00:00 2001
From: Gunhan Gulsoy 
Date: Wed, 27 Nov 2019 22:52:14 -0800
Subject: [PATCH 087/279] Remove dependence on lib/core:errors in
 tensorflow/core/platform

Also minor formatting fixes.

PiperOrigin-RevId: 282887344
Change-Id: I38a5cb1c48d2cbee14c5f702e3ad2ead23f59e2b
---
 tensorflow/core/platform/BUILD                         |  2 +-
 tensorflow/core/platform/cloud/BUILD                   |  9 +++++++++
 tensorflow/core/platform/cloud/auth_provider.h         |  2 +-
 tensorflow/core/platform/cloud/curl_http_request.cc    |  2 ++
 tensorflow/core/platform/cloud/curl_http_request.h     |  2 +-
 tensorflow/core/platform/cloud/gcs_file_system.cc      |  2 +-
 tensorflow/core/platform/cloud/gcs_file_system_test.cc |  1 +
 tensorflow/core/platform/cloud/google_auth_provider.cc |  3 ++-
 tensorflow/core/platform/cloud/http_request.h          |  2 +-
 tensorflow/core/platform/cloud/http_request_fake.h     |  2 +-
 tensorflow/core/platform/cloud/oauth_client.cc         |  3 ++-
 tensorflow/core/platform/cloud/retrying_file_system.h  |  2 +-
 tensorflow/core/platform/cloud/retrying_utils.cc       |  3 ++-
 tensorflow/core/platform/cloud/time_util.cc            |  2 +-
 tensorflow/core/platform/cloud/zone_provider.h         |  2 +-
 tensorflow/core/platform/default/build_refactor.bzl    | 10 +++++-----
 .../core/platform/default/human_readable_json.cc       |  2 +-
 tensorflow/core/platform/default/load_library.cc       |  2 +-
 tensorflow/core/platform/env.cc                        |  2 +-
 tensorflow/core/platform/env.h                         |  2 +-
 tensorflow/core/platform/file_system.cc                |  2 +-
 tensorflow/core/platform/file_system.h                 |  2 +-
 tensorflow/core/platform/protobuf_internal.h           |  2 +-
 tensorflow/core/platform/windows/load_library.cc       |  2 +-
 24 files changed, 40 insertions(+), 25 deletions(-)

diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD
index 1c85e9d0769..87b6fa1af1b 100644
--- a/tensorflow/core/platform/BUILD
+++ b/tensorflow/core/platform/BUILD
@@ -369,10 +369,10 @@ cc_library(
     name = "protobuf_internal",
     hdrs = ["protobuf_internal.h"],
     deps = [
+        ":errors",
         ":platform",
         ":protobuf",
         ":types",
-        "//tensorflow/core/lib/core:errors",
     ] + if_static(["@com_google_protobuf//:protobuf"]),
 )
 
diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD
index e2d1f7b4e37..8c87b5c4bcf 100644
--- a/tensorflow/core/platform/cloud/BUILD
+++ b/tensorflow/core/platform/cloud/BUILD
@@ -178,6 +178,8 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:test",
+        "//tensorflow/core/platform:errors",
+        "//tensorflow/core/platform:status",
         "//tensorflow/core/platform:stringpiece",
         "@curl",
     ],
@@ -197,6 +199,8 @@ cc_library(
         ":retrying_utils",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/platform:errors",
+        "//tensorflow/core/platform:status",
         "@com_google_absl//absl/strings",
         "@jsoncpp_git//:jsoncpp",
     ],
@@ -234,6 +238,8 @@ cc_library(
         ":compute_engine_metadata_client",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/platform:errors",
+        "//tensorflow/core/platform:status",
         "//tensorflow/core/platform:str_util",
     ],
 )
@@ -263,6 +269,8 @@ cc_library(
         ":http_request",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/platform:errors",
+        "//tensorflow/core/platform:status",
         "@boringssl//:crypto",
         "@jsoncpp_git//:jsoncpp",
     ],
@@ -348,6 +356,7 @@ tf_cc_test(
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "//tensorflow/core/platform:errors",
         "//tensorflow/core/platform:str_util",
     ],
 )
diff --git a/tensorflow/core/platform/cloud/auth_provider.h b/tensorflow/core/platform/cloud/auth_provider.h
index 4c219b70221..954c861169b 100644
--- a/tensorflow/core/platform/cloud/auth_provider.h
+++ b/tensorflow/core/platform/cloud/auth_provider.h
@@ -18,7 +18,7 @@ limitations under the License.
 
 #include 
 
-#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/status.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/platform/cloud/curl_http_request.cc b/tensorflow/core/platform/cloud/curl_http_request.cc
index ea007affac7..f53fce63750 100644
--- a/tensorflow/core/platform/cloud/curl_http_request.cc
+++ b/tensorflow/core/platform/cloud/curl_http_request.cc
@@ -20,6 +20,8 @@ limitations under the License.
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/strings/scanner.h"
+#include "tensorflow/core/platform/cloud/curl_http_request.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/str_util.h"
 #include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/core/platform/cloud/curl_http_request.h b/tensorflow/core/platform/cloud/curl_http_request.h
index 2e0e368a32b..b8e9aeb3399 100644
--- a/tensorflow/core/platform/cloud/curl_http_request.h
+++ b/tensorflow/core/platform/cloud/curl_http_request.h
@@ -21,9 +21,9 @@ limitations under the License.
 #include 
 
 #include 
-#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/platform/cloud/http_request.h"
 #include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/status.h"
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index c461ab61313..8c4cb831346 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -27,7 +27,6 @@ limitations under the License.
 #endif
 #include "absl/base/macros.h"
 #include "include/json/json.h"
-#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/lib/strings/numbers.h"
@@ -38,6 +37,7 @@ limitations under the License.
 #include "tensorflow/core/platform/cloud/retrying_utils.h"
 #include "tensorflow/core/platform/cloud/time_util.h"
 #include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/str_util.h"
diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
index 387c6391a15..71121afbd98 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/cloud/http_request_fake.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/str_util.h"
 #include "tensorflow/core/platform/test.h"
 
diff --git a/tensorflow/core/platform/cloud/google_auth_provider.cc b/tensorflow/core/platform/cloud/google_auth_provider.cc
index e91a9f89757..bb52a5a7ca7 100644
--- a/tensorflow/core/platform/cloud/google_auth_provider.cc
+++ b/tensorflow/core/platform/cloud/google_auth_provider.cc
@@ -22,13 +22,14 @@ limitations under the License.
 #endif
 #include 
 #include 
+
 #include "absl/strings/match.h"
 #include "include/json/json.h"
-#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/lib/strings/base64.h"
 #include "tensorflow/core/platform/cloud/retrying_utils.h"
 #include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/errors.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/platform/cloud/http_request.h b/tensorflow/core/platform/cloud/http_request.h
index 209e51407f9..5681293915f 100644
--- a/tensorflow/core/platform/cloud/http_request.h
+++ b/tensorflow/core/platform/cloud/http_request.h
@@ -20,8 +20,8 @@ limitations under the License.
 #include 
 #include 
 
-#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/status.h"
diff --git a/tensorflow/core/platform/cloud/http_request_fake.h b/tensorflow/core/platform/cloud/http_request_fake.h
index df0fe9eeb6b..9564aa7d30b 100644
--- a/tensorflow/core/platform/cloud/http_request_fake.h
+++ b/tensorflow/core/platform/cloud/http_request_fake.h
@@ -22,9 +22,9 @@ limitations under the License.
 #include 
 
 #include 
-#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/cloud/curl_http_request.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/status.h"
diff --git a/tensorflow/core/platform/cloud/oauth_client.cc b/tensorflow/core/platform/cloud/oauth_client.cc
index 89b1056be7d..69ba1f0926e 100644
--- a/tensorflow/core/platform/cloud/oauth_client.cc
+++ b/tensorflow/core/platform/cloud/oauth_client.cc
@@ -22,14 +22,15 @@ limitations under the License.
 #include 
 #endif
 #include 
+
 #include 
 #include 
 #include 
 #include 
-#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/strings/base64.h"
 #include "tensorflow/core/platform/cloud/curl_http_request.h"
 #include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/errors.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/platform/cloud/retrying_file_system.h b/tensorflow/core/platform/cloud/retrying_file_system.h
index 5e85447fd3d..12bbc7d6abb 100644
--- a/tensorflow/core/platform/cloud/retrying_file_system.h
+++ b/tensorflow/core/platform/cloud/retrying_file_system.h
@@ -20,10 +20,10 @@ limitations under the License.
 #include 
 #include 
 
-#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/platform/cloud/retrying_utils.h"
 #include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/file_system.h"
 #include "tensorflow/core/platform/status.h"
 
diff --git a/tensorflow/core/platform/cloud/retrying_utils.cc b/tensorflow/core/platform/cloud/retrying_utils.cc
index 9c963dd82f2..1f0c41824bf 100644
--- a/tensorflow/core/platform/cloud/retrying_utils.cc
+++ b/tensorflow/core/platform/cloud/retrying_utils.cc
@@ -14,9 +14,10 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/platform/cloud/retrying_utils.h"
-#include "tensorflow/core/lib/core/errors.h"
+
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/file_system.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/platform/cloud/time_util.cc b/tensorflow/core/platform/cloud/time_util.cc
index afd06efa854..c780bd25e85 100644
--- a/tensorflow/core/platform/cloud/time_util.cc
+++ b/tensorflow/core/platform/cloud/time_util.cc
@@ -21,7 +21,7 @@ limitations under the License.
 #ifdef _WIN32
 #define timegm _mkgmtime
 #endif
-#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/errors.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/platform/cloud/zone_provider.h b/tensorflow/core/platform/cloud/zone_provider.h
index 6f809ceb381..d1682fa81cc 100644
--- a/tensorflow/core/platform/cloud/zone_provider.h
+++ b/tensorflow/core/platform/cloud/zone_provider.h
@@ -18,7 +18,7 @@ limitations under the License.
 
 #include 
 
-#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/status.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/platform/default/build_refactor.bzl b/tensorflow/core/platform/default/build_refactor.bzl
index eb12e6b5585..12246af4460 100644
--- a/tensorflow/core/platform/default/build_refactor.bzl
+++ b/tensorflow/core/platform/default/build_refactor.bzl
@@ -79,7 +79,6 @@ TF_DEFAULT_PLATFORM_LIBRARIES = {
             "//third_party/eigen3",
             "//tensorflow/core/lib/core:blocking_counter",
             "//tensorflow/core/lib/core:error_codes_proto_cc",
-            "//tensorflow/core/lib/core:errors",
             "//tensorflow/core/lib/core:stringpiece",
             "//tensorflow/core/lib/io:path",
             "//tensorflow/core/platform",
@@ -87,6 +86,7 @@ TF_DEFAULT_PLATFORM_LIBRARIES = {
             "//tensorflow/core/platform:cord",
             "//tensorflow/core/platform:denormal",
             "//tensorflow/core/platform:error",
+            "//tensorflow/core/platform:errors",
             "//tensorflow/core/platform:env_time",
             "//tensorflow/core/platform:file_statistics",
             "//tensorflow/core/platform:load_library",
@@ -131,7 +131,7 @@ TF_DEFAULT_PLATFORM_LIBRARIES = {
             "//tensorflow/core/platform:default/human_readable_json.cc",
         ],
         "deps": [
-            "//tensorflow/core/lib/core:errors",
+            "//tensorflow/core/platform:errors",
             "//tensorflow/core/platform:protobuf",
             "//tensorflow/core/platform:status",
             "//tensorflow/core/platform:strcat",
@@ -148,7 +148,7 @@ TF_DEFAULT_PLATFORM_LIBRARIES = {
             "//tensorflow/core/platform:default/load_library.cc",
         ],
         "deps": [
-            "//tensorflow/core/lib/core:errors",
+            "//tensorflow/core/platform:errors",
             "//tensorflow/core/platform:status",
         ],
         "visibility": ["//visibility:private"],
@@ -404,7 +404,6 @@ TF_WINDOWS_PLATFORM_LIBRARIES = {
             "//third_party/eigen3",
             "//tensorflow/core/lib/core:blocking_counter",
             "//tensorflow/core/lib/core:error_codes_proto_cc",
-            "//tensorflow/core/lib/core:errors",
             "//tensorflow/core/lib/core:stringpiece",
             "//tensorflow/core/lib/io:path",
             "//tensorflow/core/platform",
@@ -412,6 +411,7 @@ TF_WINDOWS_PLATFORM_LIBRARIES = {
             "//tensorflow/core/platform:cord",
             "//tensorflow/core/platform:denormal",
             "//tensorflow/core/platform:error",
+            "//tensorflow/core/platform:errors",
             "//tensorflow/core/platform:env_time",
             "//tensorflow/core/platform:file_statistics",
             "//tensorflow/core/platform:load_library",
@@ -457,7 +457,7 @@ TF_WINDOWS_PLATFORM_LIBRARIES = {
             "//tensorflow/core/platform:windows/load_library.cc",
         ],
         "deps": [
-            "//tensorflow/core/lib/core:errors",
+            "//tensorflow/core/platform:errors",
             "//tensorflow/core/platform:status",
             "//tensorflow/core/platform:windows_wide_char_impl",
         ],
diff --git a/tensorflow/core/platform/default/human_readable_json.cc b/tensorflow/core/platform/default/human_readable_json.cc
index 2ecbf437800..88ab9aa87fc 100644
--- a/tensorflow/core/platform/default/human_readable_json.cc
+++ b/tensorflow/core/platform/default/human_readable_json.cc
@@ -15,7 +15,7 @@ limitations under the License.
 
 #include "tensorflow/core/platform/human_readable_json.h"
 
-#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/strcat.h"
 
diff --git a/tensorflow/core/platform/default/load_library.cc b/tensorflow/core/platform/default/load_library.cc
index eaa68e66704..ef9edcc4501 100644
--- a/tensorflow/core/platform/default/load_library.cc
+++ b/tensorflow/core/platform/default/load_library.cc
@@ -17,7 +17,7 @@ limitations under the License.
 
 #include 
 
-#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/errors.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/platform/env.cc b/tensorflow/core/platform/env.cc
index 602915540ee..301b4c0e81e 100644
--- a/tensorflow/core/platform/env.cc
+++ b/tensorflow/core/platform/env.cc
@@ -21,9 +21,9 @@ limitations under the License.
 #include 
 #include 
 
-#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/env_time.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/host_info.h"
 #include "tensorflow/core/platform/platform.h"
 #include "tensorflow/core/platform/protobuf.h"
diff --git a/tensorflow/core/platform/env.h b/tensorflow/core/platform/env.h
index cd6a6488e52..d5a22b1de2d 100644
--- a/tensorflow/core/platform/env.h
+++ b/tensorflow/core/platform/env.h
@@ -23,8 +23,8 @@ limitations under the License.
 #include 
 #include 
 
-#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/platform/env_time.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/file_system.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/mutex.h"
diff --git a/tensorflow/core/platform/file_system.cc b/tensorflow/core/platform/file_system.cc
index 58d14e3d2d3..3a1d40e50e2 100644
--- a/tensorflow/core/platform/file_system.cc
+++ b/tensorflow/core/platform/file_system.cc
@@ -20,9 +20,9 @@ limitations under the License.
 #include 
 #include 
 
-#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/platform.h"
 #include "tensorflow/core/platform/str_util.h"
 #include "tensorflow/core/platform/strcat.h"
diff --git a/tensorflow/core/platform/file_system.h b/tensorflow/core/platform/file_system.h
index ade98a12637..caeedbffbc1 100644
--- a/tensorflow/core/platform/file_system.h
+++ b/tensorflow/core/platform/file_system.h
@@ -23,8 +23,8 @@ limitations under the License.
 #include 
 #include 
 
-#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/platform/cord.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/file_statistics.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/platform.h"
diff --git a/tensorflow/core/platform/protobuf_internal.h b/tensorflow/core/platform/protobuf_internal.h
index d0cfde09bc1..bf72968a157 100644
--- a/tensorflow/core/platform/protobuf_internal.h
+++ b/tensorflow/core/platform/protobuf_internal.h
@@ -17,7 +17,7 @@ limitations under the License.
 #define TENSORFLOW_CORE_PLATFORM_PROTOBUF_INTERNAL_H_
 
 #include "google/protobuf/any.pb.h"
-#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/platform.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/core/platform/windows/load_library.cc b/tensorflow/core/platform/windows/load_library.cc
index 177253debdc..f95e770cc6b 100644
--- a/tensorflow/core/platform/windows/load_library.cc
+++ b/tensorflow/core/platform/windows/load_library.cc
@@ -25,7 +25,7 @@ limitations under the License.
 #undef LoadLibrary
 #undef ERROR
 
-#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/windows/wide_char.h"
 
 #pragma comment(lib, "Shlwapi.lib")

From eca32c17d058ed087e3851287074bba5160e605d Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Thu, 28 Nov 2019 01:02:50 -0800
Subject: [PATCH 088/279] compat: Update forward compatibility horizon to
 2019-11-28

PiperOrigin-RevId: 282898846
Change-Id: I7b2cbc16eb51e2f6fac09da8b8fdfc909cbd3221
---
 tensorflow/python/compat/compat.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 7c1d9f9f238..1f75de96001 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -31,7 +31,7 @@ from tensorflow.python.util.tf_export import tf_export
 # This value changes every day with an automatic CL. It can be modified in code
 # via `forward_compatibility_horizon()` or with the environment variable
 # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2019, 11, 27)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2019, 11, 28)
 _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
 _FORWARD_COMPATIBILITY_DATE_NUMBER = None
 

From 31013d5533618f2ed3375e4c068f8581f62b4cc8 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Thu, 28 Nov 2019 01:54:36 -0800
Subject: [PATCH 089/279] Add deprecation note for GL delegate.

PiperOrigin-RevId: 282904967
Change-Id: Ib980fbb84ffd2a507e1d14181ce7b056380fd614
---
 tensorflow/lite/delegates/gpu/BUILD         |  1 +
 tensorflow/lite/delegates/gpu/gl_delegate.h | 11 +++++++++++
 2 files changed, 12 insertions(+)

diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD
index 83fa5872a0f..d73cf4686e3 100644
--- a/tensorflow/lite/delegates/gpu/BUILD
+++ b/tensorflow/lite/delegates/gpu/BUILD
@@ -37,6 +37,7 @@ cc_library(
         "//conditions:default": [],
     }),
     deps = [
+        "@com_google_absl//absl/base",
         "@com_google_absl//absl/types:span",
         "//tensorflow/lite:kernel_api",
         "//tensorflow/lite:minimal_logging",
diff --git a/tensorflow/lite/delegates/gpu/gl_delegate.h b/tensorflow/lite/delegates/gpu/gl_delegate.h
index f1d30fd946e..bfc15fb120e 100644
--- a/tensorflow/lite/delegates/gpu/gl_delegate.h
+++ b/tensorflow/lite/delegates/gpu/gl_delegate.h
@@ -19,6 +19,7 @@ limitations under the License.
 #include 
 
 #include 
+#include "absl/base/macros.h"
 #include "tensorflow/lite/c/common.h"
 
 #ifdef SWIG
@@ -39,6 +40,15 @@ limitations under the License.
 extern "C" {
 #endif  // __cplusplus
 
+// WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING
+//
+// GPU delegate declared in this file is OBSOLETE and replaced with the delegate
+// declared in delegate.h. New delegate combines all GL, CL and soon
+// Vulkan-based implementations in one.
+// Please migrate before end of 2019.
+//
+// WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING
+
 // LINT.IfChange
 enum TfLiteGlObjectType {
   TFLITE_GL_OBJECT_TYPE_FASTEST = 0,
@@ -109,6 +119,7 @@ TFL_CAPI_EXPORT TfLiteGpuDelegateOptions TfLiteGpuDelegateOptionsDefault();
 //   .preferred_gl_object_type = TFLITE_GL_OBJECT_TYPE_FASTEST,
 //   .dynamic_batch_enabled = false,
 // },
+ABSL_DEPRECATED("Use TfLiteGpuDelegateV2Create defined in delegate.h instead.")
 TFL_CAPI_EXPORT TfLiteDelegate* TfLiteGpuDelegateCreate(
     const TfLiteGpuDelegateOptions* options);
 

From 6596a600669d2dea9b3204e07e0943db071c42d1 Mon Sep 17 00:00:00 2001
From: Jose Ignacio Gomez 
Date: Thu, 28 Nov 2019 01:59:22 -0800
Subject: [PATCH 090/279] [Linalg] Change attribute n_loop_types to iterator

This addresses issue #270. Linalg is updated to take the same form
of iterator_types than vector contraction.

Closes #280

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/280 from tetuante:PRissue270 d26d88d090d3765d3b9884bfabdd023143f27287
PiperOrigin-RevId: 282905396
Change-Id: I1c55a92690dd31c28f9123b08dd482b52745681c
---
 .../mlir/xla/tests/lhlo-fuse-linalg.mlir      |  2 +-
 .../xla/transforms/lhlo_legalize_to_linalg.cc | 35 ++++++------
 .../Dialect/Linalg/IR/LinalgLibraryOps.td     | 57 +++++++++++--------
 3 files changed, 51 insertions(+), 43 deletions(-)

diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir
index 00c37da8c5e..2d23a5fb1f9 100644
--- a/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir
@@ -1,7 +1,7 @@
 // RUN: tf-opt -lhlo-fuse-linalg %s -o - | FileCheck %s
 
 #map0 = (d0, d1) -> (d0, d1)
-#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], n_loop_types = [2, 0, 0], n_views = [2, 1]}
+#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"], n_views = [2, 1]}
 func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
              %summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
   %temp_result = alloc() {temp = true} : memref<2x2xf32>
diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_linalg.cc
index 8920204abf3..28bacfa87f0 100644
--- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_linalg.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_linalg.cc
@@ -38,6 +38,15 @@ namespace mlir {
 namespace xla_lhlo {
 namespace {
 
+ArrayAttr GetNParallelLoopsAttrs(unsigned nParallelLoops, Builder b) {
+  auto parallelLoopTypeAttr = b.getStringAttr("parallel");
+  SmallVector iteratorTypes;
+  for (int i = 0; i < nParallelLoops; ++i) {
+    iteratorTypes.push_back(parallelLoopTypeAttr);
+  }
+  return b.getArrayAttr(iteratorTypes);
+}
+
 template 
 class PointwiseToLinalgConverter : public OpConversionPattern {
  public:
@@ -78,11 +87,6 @@ class PointwiseToLinalgConverter : public OpConversionPattern {
       result_or_body_arg.emplace_back(memrefType.getElementType());
     }
 
-    // Pointwise-ops have all surrounding loops parallel, so the loop triple is
-    // [argDim, 0, 0].
-    SmallVector loop_types{rewriter.getI64IntegerAttr(nloops),
-                                         rewriter.getI64IntegerAttr(0),
-                                         rewriter.getI64IntegerAttr(0)};
     // Define the number of input memref/output memrefs.
     SmallVector nmemrefs{
         rewriter.getI64IntegerAttr(bodyArgTypes.size()),
@@ -90,7 +94,8 @@ class PointwiseToLinalgConverter : public OpConversionPattern {
 
     auto linalgOp = rewriter.create(
         loc, args, rewriter.getArrayAttr(indexingMaps),
-        rewriter.getArrayAttr(loop_types), rewriter.getArrayAttr(nmemrefs),
+        GetNParallelLoopsAttrs(nloops, rewriter),
+        rewriter.getArrayAttr(nmemrefs),
         /*doc=*/nullptr, /*fun=*/nullptr, /*library_call=*/nullptr);
 
     // Add a block to the region.
@@ -158,11 +163,6 @@ class BroadcastInDimConverter : public OpConversionPattern {
     indexingMaps.emplace_back(
         AffineMapAttr::get(rewriter.getMultiDimIdentityMap(nloops)));
 
-    // Broadcast op has all surrounding loops parallel, so the loop triple is
-    // [argDim, 0, 0].
-    SmallVector loop_types{rewriter.getI64IntegerAttr(nloops),
-                                         rewriter.getI64IntegerAttr(0),
-                                         rewriter.getI64IntegerAttr(0)};
     // Define the number of input memref/output memrefs.
     SmallVector nmemrefs{
         rewriter.getI64IntegerAttr(bodyArgTypes.size()),
@@ -171,7 +171,8 @@ class BroadcastInDimConverter : public OpConversionPattern {
     auto loc = broadcastOp.getLoc();
     auto linalgOp = rewriter.create(
         loc, args, rewriter.getArrayAttr(indexingMaps),
-        rewriter.getArrayAttr(loop_types), rewriter.getArrayAttr(nmemrefs),
+        GetNParallelLoopsAttrs(nloops, rewriter),
+        rewriter.getArrayAttr(nmemrefs),
         /*doc=*/nullptr, /*fun=*/nullptr, /*library_call=*/nullptr);
 
     // Add a block to the region.
@@ -207,11 +208,6 @@ class IotaConverter : public OpConversionPattern {
     indexingMaps.emplace_back(
         AffineMapAttr::get(rewriter.getMultiDimIdentityMap(nloops)));
 
-    // Pointwise-ops have all surrounding loops parallel, so the loop triple is
-    // [argDim, 0, 0].
-    SmallVector loop_types{rewriter.getI64IntegerAttr(nloops),
-                                         rewriter.getI64IntegerAttr(0),
-                                         rewriter.getI64IntegerAttr(0)};
     // Define the number of input memref/output memrefs.
     SmallVector nmemrefs{rewriter.getI64IntegerAttr(0),
                                        rewriter.getI64IntegerAttr(1)};
@@ -219,7 +215,8 @@ class IotaConverter : public OpConversionPattern {
     auto loc = iotaOp.getLoc();
     auto linalgOp = rewriter.create(
         loc, args, rewriter.getArrayAttr(indexingMaps),
-        rewriter.getArrayAttr(loop_types), rewriter.getArrayAttr(nmemrefs),
+        GetNParallelLoopsAttrs(nloops, rewriter),
+        rewriter.getArrayAttr(nmemrefs),
         /*doc=*/nullptr, /*fun=*/nullptr, /*library_call=*/nullptr);
 
     // Add a block to the region.
@@ -277,7 +274,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
 //     "linalg.yield"(%0) : (f32) -> ()
 //   }) {
 //     indexing_maps = [#map0, #map0, #map0],
-//     n_loop_types = [2, 0, 0],
+//     iterator_types = ["parallel", "parallel"],
 //     n_views = [2, 1]
 //   } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
 // }
diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td b/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td
index 92b325b5943..e0070a8da35 100644
--- a/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td
+++ b/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td
@@ -368,7 +368,7 @@ def ConvOp : LinalgLibrary_Op<"conv", [NInputsAndOutputs<2, 1>]> {
 class GenericOpBase : LinalgLibraryBase_Op {
   let arguments = (ins Variadic:$views,
                    AffineMapArrayAttr:$indexing_maps,
-                   I64ArrayAttr:$n_loop_types,
+                   ArrayAttr:$iterator_types,
                    I64ArrayAttr:$n_views,
                    OptionalAttr:$doc,
                    OptionalAttr:$fun,
@@ -377,7 +377,7 @@ class GenericOpBase : LinalgLibraryBase_Op {
   let extraClassDeclaration = [{
     SmallVector linalgTraitAttrNames() {
       return SmallVector{
-        "doc", "fun", "indexing_maps", "library_call", "n_loop_types", "n_views"
+        "doc", "fun", "indexing_maps", "library_call", "iterator_types", "n_views"
       };
     }
     unsigned getNumInputs() {
@@ -395,26 +395,35 @@ class GenericOpBase : LinalgLibraryBase_Op {
       return val.getZExtValue();
     }
     unsigned getNumParallelLoops() {
-      if (!getAttr("n_loop_types") || n_loop_types().getValue().size() != 3)
+      if (!getAttr("iterator_types") || iterator_types().getValue().size() == 0)
         return 0;
-      auto val = n_loop_types().getValue()[0].cast().getValue();
-      assert(val.getSExtValue() >= 0);
-      return val.getZExtValue();
+      unsigned nPar = 0;
+      for (auto ty : iterator_types()) {
+        if (ty.cast().getValue() == "parallel")
+          nPar++;
+      }
+      return nPar;
     }
     unsigned getNumReductionLoops() {
-      if (!getAttr("n_loop_types") || n_loop_types().getValue().size() != 3)
+      if (!getAttr("iterator_types") || iterator_types().getValue().size() == 0)
         return 0;
-      auto val = n_loop_types().getValue()[1].cast().getValue();
-      assert(val.getSExtValue() >= 0);
-      return val.getZExtValue();
-    }
-    unsigned getNumWindowLoops() {
-      if (!getAttr("n_loop_types") || n_loop_types().getValue().size() != 3)
+      unsigned nRed = 0;
+      for (auto ty : iterator_types()) {
+        if (ty.cast().getValue() == "reduction")
+          nRed++;
+      }
+      return nRed;
+   }
+   unsigned getNumWindowLoops() {
+      if (!getAttr("iterator_types") || iterator_types().getValue().size() == 0)
         return 0;
-      auto val = n_loop_types().getValue()[2].cast().getValue();
-      assert(val.getSExtValue() >= 0);
-      return val.getZExtValue();
-    }
+      unsigned nWin = 0;
+      for (auto ty : iterator_types()) {
+        if (ty.cast().getValue() == "window")
+          nWin++;
+      }
+      return nWin;
+   }
     unsigned getNumLoops() {
       return getNumParallelLoops() + getNumReductionLoops() +
         getNumWindowLoops();
@@ -474,8 +483,9 @@ def GenericOp : GenericOpBase<"generic"> {
         The external library is assumed to be dynamically linked and no strong
         compile-time guarantees are provided. In the absence of such a library
         call, linalg.generic will always lower to loops.
-      - n_loops: a triple of I64Attr representing the number of enclosing
-        [parallel, reduction, window] loops respectively.
+      - iterator_types: an ArrayAttr they type of the enclosing loops; Each element of
+        the list represents and iterator of one of the following types:
+        parallel, reduction, window
       - n_views: a pair of I64Attr representing the number of input (readonly)
         and output (readwrite) views.
 
@@ -498,7 +508,7 @@ def GenericOp : GenericOpBase<"generic"> {
           indexing_maps = #matmul_accesses,
           library_call = "linalg_matmul",
           n_views = [2, 1],
-          n_loop_types = [2, 1, 0]
+          iterator_types = ["parallel", "parallel", "reduction"]
         }
       ```
 
@@ -568,8 +578,9 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
         maps to.  The external library is assumed to be dynamically linked and
         no strong compile-time guarantees are provided. In the absence of such
         a library call, linalg.indexed_generic will always lower to loops.
-      - n_loops: a triple of I64Attr representing the number of enclosing
-        [parallel, reduction, window] loops respectively.
+      - iterator_types: an ArrayAttr they type of the enclosing loops; Each element of
+        the list represents and iterator of one of the following types:
+        parallel, reduction, window
       - n_views: a pair of I64Attr representing the number of input (readonly)
         and output (readwrite) views.
 
@@ -592,7 +603,7 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
           indexing_maps = #matmul_accesses,
           library_call = "linalg_matmul",
           n_views = [2, 1],
-          n_loop_types = [2, 1, 0]
+          iterator_types = ["parallel", "parallel", "reduction"]
         }
       ```
 

From b2ddf60f1dda8cf16ee78fd0a797dfcc815800c5 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Thu, 28 Nov 2019 02:23:02 -0800
Subject: [PATCH 091/279] Add deprecation note for GL delegate.

PiperOrigin-RevId: 282908052
Change-Id: I08cbc8d1dbf0b718d808b23649f1c457b93cd059
---
 tensorflow/lite/delegates/gpu/BUILD         |  1 -
 tensorflow/lite/delegates/gpu/gl_delegate.h | 11 -----------
 2 files changed, 12 deletions(-)

diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD
index d73cf4686e3..83fa5872a0f 100644
--- a/tensorflow/lite/delegates/gpu/BUILD
+++ b/tensorflow/lite/delegates/gpu/BUILD
@@ -37,7 +37,6 @@ cc_library(
         "//conditions:default": [],
     }),
     deps = [
-        "@com_google_absl//absl/base",
         "@com_google_absl//absl/types:span",
         "//tensorflow/lite:kernel_api",
         "//tensorflow/lite:minimal_logging",
diff --git a/tensorflow/lite/delegates/gpu/gl_delegate.h b/tensorflow/lite/delegates/gpu/gl_delegate.h
index bfc15fb120e..f1d30fd946e 100644
--- a/tensorflow/lite/delegates/gpu/gl_delegate.h
+++ b/tensorflow/lite/delegates/gpu/gl_delegate.h
@@ -19,7 +19,6 @@ limitations under the License.
 #include 
 
 #include 
-#include "absl/base/macros.h"
 #include "tensorflow/lite/c/common.h"
 
 #ifdef SWIG
@@ -40,15 +39,6 @@ limitations under the License.
 extern "C" {
 #endif  // __cplusplus
 
-// WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING
-//
-// GPU delegate declared in this file is OBSOLETE and replaced with the delegate
-// declared in delegate.h. New delegate combines all GL, CL and soon
-// Vulkan-based implementations in one.
-// Please migrate before end of 2019.
-//
-// WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING
-
 // LINT.IfChange
 enum TfLiteGlObjectType {
   TFLITE_GL_OBJECT_TYPE_FASTEST = 0,
@@ -119,7 +109,6 @@ TFL_CAPI_EXPORT TfLiteGpuDelegateOptions TfLiteGpuDelegateOptionsDefault();
 //   .preferred_gl_object_type = TFLITE_GL_OBJECT_TYPE_FASTEST,
 //   .dynamic_batch_enabled = false,
 // },
-ABSL_DEPRECATED("Use TfLiteGpuDelegateV2Create defined in delegate.h instead.")
 TFL_CAPI_EXPORT TfLiteDelegate* TfLiteGpuDelegateCreate(
     const TfLiteGpuDelegateOptions* options);
 

From 426d3bbea497f606d4a9bee67de24fbd9e1848cf Mon Sep 17 00:00:00 2001
From: Adrian Kuegel 
Date: Thu, 28 Nov 2019 03:40:08 -0800
Subject: [PATCH 092/279] Also enable dot strength reduction for complex types.

There is already a runtime test that covers this:
DotOperationTest.MatrixVectorC64

PiperOrigin-RevId: 282917925
Change-Id: I1d05fc06fda9bb8c047308bbf34045fdec2a9c63
---
 .../compiler/xla/service/algebraic_simplifier.cc   |  8 ++++++--
 .../xla/service/algebraic_simplifier_test.cc       | 14 +++++++-------
 2 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 4b6b91af122..2fe8c309cb0 100755
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -1825,7 +1825,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
   // If the lhs or rhs have only batch and contracting dimensions, a dot can be
   // rewritten as reduce(mul(broadcast(transpose(x)),broadcast(transpose(y))))
   if (options_.enable_dot_strength_reduction() &&
-      ShapeUtil::ElementIsFloating(dot->shape()) &&
+      (ShapeUtil::ElementIsFloating(dot->shape()) ||
+       ShapeUtil::ElementIsComplex(dot->shape())) &&
       ((dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
             dot->dot_dimension_numbers().lhs_contracting_dimensions_size() ==
         lhs->shape().rank()) ||
@@ -1886,7 +1887,10 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
                         MakeBinaryHlo(HloOpcode::kMultiply, new_lhs, new_rhs));
     std::vector reduce_dims(
         dot->dot_dimension_numbers().lhs_contracting_dimensions_size());
-    PrimitiveType dot_type = dot->shape().element_type() == F64 ? F64 : F32;
+    PrimitiveType dot_type =
+        ShapeUtil::ElementIsComplex(dot->shape())
+            ? dot->shape().element_type()
+            : dot->shape().element_type() == F64 ? F64 : F32;
     new_dot = AsType(new_dot, dot_type);
     const int64 outer_dims = std::max(rhs_outer_dims, lhs_outer_dims);
     absl::c_iota(
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 2618a12673f..88282986560 100755
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -4706,12 +4706,11 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) {
   EXPECT_EQ(has_no_dot, dot_should_be_transformed);
 }
 
-INSTANTIATE_TEST_SUITE_P(BatchDotStrengthReductionTestInstantiation,
-                         BatchDotStrengthReductionTest,
-                         ::testing::Combine(::testing::Values(-1, 1, 2),
-                                            ::testing::Values(-1, 1, 2),
-                                            ::testing::Values(-1, 1, 2),
-                                            ::testing::Values(F64, F32, BF16)));
+INSTANTIATE_TEST_SUITE_P(
+    BatchDotStrengthReductionTestInstantiation, BatchDotStrengthReductionTest,
+    ::testing::Combine(::testing::Values(-1, 1, 2), ::testing::Values(-1, 1, 2),
+                       ::testing::Values(-1, 1, 2),
+                       ::testing::Values(C128, C64, F64, F32, BF16)));
 
 class DotStrengthReductionTest
     : public AlgebraicSimplifierTest,
@@ -4775,7 +4774,8 @@ INSTANTIATE_TEST_SUITE_P(
     DotStrengthReductionTestInstantiation, DotStrengthReductionTest,
     ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2),
                        ::testing::Values(1, 2), ::testing::Bool(),
-                       ::testing::Bool(), ::testing::Values(F64, F32, BF16)));
+                       ::testing::Bool(),
+                       ::testing::Values(C128, C64, F64, F32, BF16)));
 
 struct DotOfConcatTestSpec {
   int64 m;

From d393702997fcf9beb9048778413d387ed30e296c Mon Sep 17 00:00:00 2001
From: Benjamin Kramer 
Date: Thu, 28 Nov 2019 06:24:54 -0800
Subject: [PATCH 093/279] [XLA:CPU] Register replica-id IR

Otherwise any use of replica-id as an argument to another operation crashes.

PiperOrigin-RevId: 282934610
Change-Id: Iab3957f85910fd1f453c6d1b9043c9d220ece633
---
 tensorflow/compiler/xla/service/cpu/ir_emitter.cc    | 1 +
 tensorflow/compiler/xla/tests/collective_ops_test.cc | 3 ++-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 9510d1fecde..cf167a57087 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -1511,6 +1511,7 @@ Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
 }
 
 Status IrEmitter::HandleReplicaId(HloInstruction* hlo) {
+  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
   llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
   llvm::FunctionType* replica_id_function_ty =
       llvm::FunctionType::get(b_.getVoidTy(),
diff --git a/tensorflow/compiler/xla/tests/collective_ops_test.cc b/tensorflow/compiler/xla/tests/collective_ops_test.cc
index a2de65502b4..42f687a7996 100644
--- a/tensorflow/compiler/xla/tests/collective_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/collective_ops_test.cc
@@ -397,7 +397,8 @@ XLA_TEST_F(CollectiveOpsTest, ReplicaId) {
   const char* const kModuleStr = R"(
   HloModule test
   ENTRY test_computation {
-    ROOT id = u32[] replica-id()
+    id = u32[] replica-id()
+    ROOT out = u32[] copy(id)
   }
   )";
   const int64 kNumReplicas = 4;

From 75e5b5d70b6f33bd41fdf07b844c762b23f99d1b Mon Sep 17 00:00:00 2001
From: Adrian Kuegel 
Date: Thu, 28 Nov 2019 06:34:25 -0800
Subject: [PATCH 094/279] Try to avoid overflows in accumulation results.

This can be done by upcasting to an integer type with more bits.

PiperOrigin-RevId: 282935503
Change-Id: Iaf62534f9832ee93ed84e33b9df85068bc5e6941
---
 tensorflow/compiler/tf2xla/xla_helpers.cc | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc
index 7bb1ad27467..74247bbaec7 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.cc
+++ b/tensorflow/compiler/tf2xla/xla_helpers.cc
@@ -111,6 +111,13 @@ DataType XlaHelpers::SumAccumulationType(const DataType& dtype) {
   if (dtype == DT_BFLOAT16 || dtype == DT_HALF) {
     return DT_FLOAT;
   }
+  // Upcast small integer types to 32 bit to avoid overflow.
+  if (dtype == DT_INT8 || dtype == DT_INT16) {
+    return DT_INT32;
+  }
+  if (dtype == DT_UINT8 || dtype == DT_UINT16) {
+    return DT_UINT32;
+  }
   return dtype;
 }
 

From b5bfebf6669982ccf818c3e9a69197ceca9dc456 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Thu, 28 Nov 2019 07:31:02 -0800
Subject: [PATCH 095/279] Comment style nitpick for StrategyExtendV2.

PiperOrigin-RevId: 282941111
Change-Id: Ide992cfe95dd02f5ecb1627c758326d577c49f9e
---
 tensorflow/python/distribute/distribute_lib.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py
index 94ba10783f7..4a2a8af1840 100644
--- a/tensorflow/python/distribute/distribute_lib.py
+++ b/tensorflow/python/distribute/distribute_lib.py
@@ -1164,8 +1164,8 @@ class StrategyExtendedV2(object):
 
   *Replica context vs. Cross-replica context*
 
-  _replica context_ is when we are in some function that is being called once
-  for each replica.  Otherwise we are in cross-replica context, which is
+  A _replica context_ applies when we are in some function that is being called
+  once for each replica.  Otherwise we are in cross-replica context, which is
   useful for calling `tf.distribute.Strategy` methods which operate across the
   replicas (like `reduce_to()`). By default you start in a replica context
   (the "default single replica context") and then some methods can switch you

From 4b736990a70bad67c0307e8aeeaa397f10744d77 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Thu, 28 Nov 2019 08:44:05 -0800
Subject: [PATCH 096/279] [Grappler] Workaround for bug in
 HoistCWiseUnaryChainsStage where duplicate inputs to Concat can cause the
 rewrite to create cycles in the graphs.

PiperOrigin-RevId: 282949141
Change-Id: Ib8fa6571399331b9b5c58aa5bfbdda955a510603
---
 .../grappler/optimizers/arithmetic_optimizer.cc   | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index d2ff480c29d..7f6940fb31d 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -1455,7 +1455,7 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
     if (IsInPreserveSet(*node)) return false;
     if (IsConcat(*node) && node->attr().count("N") != 0) {
       const int n = node->attr().at("N").i();
-      return n > 1;
+      return n > 1 && FirstNInputsAreUnique(*node, n);
     } else if ((IsSplit(*node) || IsSplitV(*node)) &&
                node->attr().count("num_split") != 0) {
       const int num_split = node->attr().at("num_split").i();
@@ -1489,6 +1489,17 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
   }
 
  private:
+  bool FirstNInputsAreUnique(const NodeDef& node, int n) const {
+    if (n > node.input_size()) return false;
+    absl::flat_hash_set unique_inputs;
+    const int start = node.op() == "Concat" ? 1 : 0;
+    const int end = start + n;
+    for (int i = start; i < end; ++i) {
+      unique_inputs.insert(node.input(i));
+    }
+    return unique_inputs.size() == n;
+  }
+
   // Returns the length of the common unary chain of ops that can be
   // hoisted to the other side of concat or split.
   Status FindCommonUnaryOpChain(const NodeDef& root_node, int* prefix_length,
@@ -1525,7 +1536,7 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
   Status HoistUnaryOpChain(const int prefix_length, const ChainLinkSet& tails,
                            std::set* ctrl_inputs, NodeDef* root_node) {
     VLOG(3) << "Hoist unary op chain:"
-            << " root=" << root_node->name()
+            << " root=" << root_node->DebugString()
             << " prefix_length=" << prefix_length << " ctrl_inputs=["
             << absl::StrJoin(*ctrl_inputs, ", ") << "]";
 

From 8c34a9bd0676453dd0bffb016969a0515b2358e6 Mon Sep 17 00:00:00 2001
From: Stefano Galarraga 
Date: Thu, 28 Nov 2019 09:04:26 -0800
Subject: [PATCH 097/279] Refactors NnApiMock to extract a class to be used to
 do failure injection on NNAPI in native tests

PiperOrigin-RevId: 282951257
Change-Id: Ib6bee9bdf54c43de98b6d85fd855f786cb85c064
---
 tensorflow/lite/delegates/nnapi/BUILD         |   1 +
 .../nnapi/nnapi_delegate_mock_test.h          | 132 +-----------
 tensorflow/lite/nnapi/BUILD                   |  26 ++-
 tensorflow/lite/nnapi/nnapi_handler.cc        |  44 ++++
 tensorflow/lite/nnapi/nnapi_handler.h         | 197 ++++++++++++++++++
 tensorflow/lite/nnapi/nnapi_handler_test.cc   | 143 +++++++++++++
 6 files changed, 413 insertions(+), 130 deletions(-)
 create mode 100644 tensorflow/lite/nnapi/nnapi_handler.cc
 create mode 100644 tensorflow/lite/nnapi/nnapi_handler.h
 create mode 100644 tensorflow/lite/nnapi/nnapi_handler_test.cc

diff --git a/tensorflow/lite/delegates/nnapi/BUILD b/tensorflow/lite/delegates/nnapi/BUILD
index 572adc5a0cc..0e99e1e3b79 100644
--- a/tensorflow/lite/delegates/nnapi/BUILD
+++ b/tensorflow/lite/delegates/nnapi/BUILD
@@ -103,6 +103,7 @@ cc_library(
     }),
     deps = [
         ":nnapi_delegate",
+        "//tensorflow/lite/nnapi:nnapi_handler",
         "//tensorflow/lite/nnapi:nnapi_implementation",
         "@com_google_absl//absl/memory",
         "@com_google_googletest//:gtest",
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_mock_test.h b/tensorflow/lite/delegates/nnapi/nnapi_delegate_mock_test.h
index 8551bdea0a8..24eb06edabe 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_mock_test.h
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_mock_test.h
@@ -28,134 +28,17 @@ limitations under the License.
 #include 
 #include "absl/memory/memory.h"
 #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
+#include "tensorflow/lite/nnapi/nnapi_handler.h"
 #include "tensorflow/lite/nnapi/nnapi_implementation.h"
 
 namespace tflite {
 namespace delegate {
 namespace nnapi {
 
-class NnApiMock {
+class NnApiMock : public ::tflite::nnapi::NnApiHandler {
  public:
-  template 
-  void GetDeviceCountReturns() {
-    nnapi_->ANeuralNetworks_getDeviceCount = [](uint32_t* numDevices) -> int {
-      *numDevices = 2;
-      return Value;
-    };
-  }
-
-  template 
-  void ModelCreateReturns() {
-    nnapi_->ANeuralNetworksModel_create = [](ANeuralNetworksModel** model) {
-      *model = reinterpret_cast(1);
-      return Value;
-    };
-  }
-
-  template 
-  void AddOperandReturns() {
-    nnapi_->ANeuralNetworksModel_addOperand =
-        [](ANeuralNetworksModel* model,
-           const ANeuralNetworksOperandType* type) { return Value; };
-  }
-
-  template 
-  void SetOperandValueReturns() {
-    nnapi_->ANeuralNetworksModel_setOperandValue =
-        [](ANeuralNetworksModel* model, int32_t index, const void* buffer,
-           size_t length) { return Value; };
-  }
-
-  template 
-  void AddOperationReturns() {
-    nnapi_->ANeuralNetworksModel_addOperation =
-        [](ANeuralNetworksModel* model, ANeuralNetworksOperationType type,
-           uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount,
-           const uint32_t* outputs) { return Value; };
-  }
-
-  template 
-  void IdentifyInputAndOutputsReturns() {
-    nnapi_->ANeuralNetworksModel_identifyInputsAndOutputs =
-        [](ANeuralNetworksModel* model, uint32_t inputCount,
-           const uint32_t* inputs, uint32_t outputCount,
-           const uint32_t* outputs) { return Value; };
-  }
-
-  template 
-  void RelaxComputationFloatReturns() {
-    nnapi_->ANeuralNetworksModel_relaxComputationFloat32toFloat16 =
-        [](ANeuralNetworksModel* model, bool allow) { return Value; };
-  }
-
-  template 
-  void ModelFinishReturns() {
-    nnapi_->ANeuralNetworksModel_finish = [](ANeuralNetworksModel* model) {
-      return Value;
-    };
-  }
-
-  template 
-  void MemoryCreateFromFdReturns() {
-    nnapi_->ANeuralNetworksMemory_createFromFd =
-        [](size_t size, int protect, int fd, size_t offset,
-           ANeuralNetworksMemory** memory) {
-          *memory = reinterpret_cast(2);
-          return Value;
-        };
-  }
-
-  template 
-  void CompilationCreateReturns() {
-    nnapi_->ANeuralNetworksCompilation_create =
-        [](ANeuralNetworksModel* model,
-           ANeuralNetworksCompilation** compilation) {
-          *compilation = reinterpret_cast(3);
-          return Value;
-        };
-  }
-
-  template 
-  void CompilationFinishReturns() {
-    nnapi_->ANeuralNetworksCompilation_finish =
-        [](ANeuralNetworksCompilation* compilation) { return Value; };
-  }
-
-  template 
-  void ExecutionCreateReturns() {
-    nnapi_->ANeuralNetworksExecution_create =
-        [](ANeuralNetworksCompilation* compilation,
-           ANeuralNetworksExecution** execution) {
-          if (compilation == nullptr) return 1;
-          *execution = reinterpret_cast(4);
-          return Value;
-        };
-  }
-  template 
-  void ExecutionSetInputFromMemoryReturns() {
-    nnapi_->ANeuralNetworksExecution_setInputFromMemory =
-        [](ANeuralNetworksExecution* execution, int32_t index,
-           const ANeuralNetworksOperandType* type,
-           const ANeuralNetworksMemory* memory, size_t offset,
-           size_t length) { return Value; };
-  }
-  template 
-  void ExecutionSetOutputFromMemoryReturns() {
-    nnapi_->ANeuralNetworksExecution_setOutputFromMemory =
-        [](ANeuralNetworksExecution* execution, int32_t index,
-           const ANeuralNetworksOperandType* type,
-           const ANeuralNetworksMemory* memory, size_t offset,
-           size_t length) { return Value; };
-  }
-
-  template 
-  void ExecutionComputeReturns() {
-    nnapi_->ANeuralNetworksExecution_compute =
-        [](ANeuralNetworksExecution* execution) { return Value; };
-  }
-
   explicit NnApiMock(NnApi* nnapi, int android_sdk_version = 29)
-      : nnapi_(nnapi), prev_nnapi_(*nnapi) {
+      : ::tflite::nnapi::NnApiHandler(nnapi) {
     nnapi_->nnapi_exists = true;
     nnapi_->android_sdk_version = android_sdk_version;
 
@@ -186,14 +69,7 @@ class NnApiMock {
     ExecutionComputeReturns<0>();
   }
 
-  ~NnApiMock() {
-    // Restores global NNAPI to original value for non mocked tests
-    *nnapi_ = prev_nnapi_;
-  }
-
- private:
-  NnApi* nnapi_;
-  NnApi prev_nnapi_;
+  ~NnApiMock() { Reset(); }
 };
 
 class NnApiDelegateMockTest : public ::testing::Test {
diff --git a/tensorflow/lite/nnapi/BUILD b/tensorflow/lite/nnapi/BUILD
index e26d9567337..0a687e83131 100644
--- a/tensorflow/lite/nnapi/BUILD
+++ b/tensorflow/lite/nnapi/BUILD
@@ -57,7 +57,7 @@ cc_library(
         "//conditions:default": ["-lrt"],
     }),
     deps = [
-        "//tensorflow/lite/nnapi:nnapi_lib",
+        ":nnapi_lib",
     ],
 )
 
@@ -76,7 +76,29 @@ cc_test(
     name = "nnapi_implementation_test",
     srcs = ["nnapi_implementation_test.cc"],
     deps = [
-        "//tensorflow/lite/nnapi:nnapi_implementation",
+        ":nnapi_implementation",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+cc_library(
+    name = "nnapi_handler",
+    srcs = ["nnapi_handler.cc"],
+    hdrs = ["nnapi_handler.h"],
+    deps = [
+        ":nnapi_implementation",
+        ":nnapi_lib",
+        "//tensorflow/core/platform:logging",
+        "//tensorflow/lite:framework",
+    ],
+)
+
+cc_test(
+    name = "nnapi_handler_test",
+    srcs = ["nnapi_handler_test.cc"],
+    deps = [
+        ":nnapi_handler",
+        ":nnapi_implementation",
         "@com_google_googletest//:gtest_main",
     ],
 )
diff --git a/tensorflow/lite/nnapi/nnapi_handler.cc b/tensorflow/lite/nnapi/nnapi_handler.cc
new file mode 100644
index 00000000000..354ad66463c
--- /dev/null
+++ b/tensorflow/lite/nnapi/nnapi_handler.cc
@@ -0,0 +1,44 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/nnapi/nnapi_handler.h"
+
+#include 
+
+#include "tensorflow/lite/nnapi/nnapi_implementation.h"
+
+namespace tflite {
+namespace nnapi {
+
+const NnApi* NnApiPassthroughInstance() {
+  static const NnApi orig_nnapi_copy = *NnApiImplementation();
+  return &orig_nnapi_copy;
+}
+
+// static
+NnApiHandler* NnApiHandler::Instance() {
+  // Ensuring that the original copy of nnapi is saved before we return
+  // access to NnApiHandler
+  NnApiPassthroughInstance();
+  static NnApiHandler handler{const_cast(NnApiImplementation())};
+  return &handler;
+}
+
+void NnApiHandler::Reset() {
+  // Restores global NNAPI to original value
+  *nnapi_ = *NnApiPassthroughInstance();
+}
+
+}  // namespace nnapi
+}  // namespace tflite
diff --git a/tensorflow/lite/nnapi/nnapi_handler.h b/tensorflow/lite/nnapi/nnapi_handler.h
new file mode 100644
index 00000000000..70406ba2c6e
--- /dev/null
+++ b/tensorflow/lite/nnapi/nnapi_handler.h
@@ -0,0 +1,197 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_NNAPI_NNAPI_HANDLER_H_
+#define TENSORFLOW_LITE_NNAPI_NNAPI_HANDLER_H_
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/lite/nnapi/nnapi_implementation.h"
+
+namespace tflite {
+namespace nnapi {
+
+// Offers an interface to alter the behaviour of the NNAPI instance.
+// As for NNAPI, it is designed to be a singleton.
+// It allows to change the behaviour of some of the methods with some stub
+// implementation and then to reset the behavior to the original one using
+// Reset().
+//
+class NnApiHandler {
+ public:
+  // No destructor defined to allow this class to be used as singleton.
+
+  // Factory method, only one instance per process/jni library.
+  static NnApiHandler* Instance();
+
+  // Makes the current object a transparent proxy again, resetting any
+  // applied changes to its methods.
+  void Reset();
+
+  // Using templates in the ...Returns methods because the functions need to be
+  // stateless and the template generated code is more readable than using a
+  // file-local variable in the method implementation to store the configured
+  // result.
+
+  template 
+  void GetDeviceCountReturns() {
+    nnapi_->ANeuralNetworks_getDeviceCount = [](uint32_t* numDevices) -> int {
+      *numDevices = 2;
+      return Value;
+    };
+  }
+
+  void StubGetDeviceCountWith(int(stub)(uint32_t*)) {
+    nnapi_->ANeuralNetworks_getDeviceCount = stub;
+  }
+
+  template 
+  void ModelCreateReturns() {
+    nnapi_->ANeuralNetworksModel_create = [](ANeuralNetworksModel** model) {
+      *model = reinterpret_cast(1);
+      return Value;
+    };
+  }
+
+  template 
+  void AddOperandReturns() {
+    nnapi_->ANeuralNetworksModel_addOperand =
+        [](ANeuralNetworksModel* model,
+           const ANeuralNetworksOperandType* type) { return Value; };
+  }
+
+  template 
+  void SetOperandValueReturns() {
+    nnapi_->ANeuralNetworksModel_setOperandValue =
+        [](ANeuralNetworksModel* model, int32_t index, const void* buffer,
+           size_t length) { return Value; };
+  }
+
+  template 
+  void AddOperationReturns() {
+    nnapi_->ANeuralNetworksModel_addOperation =
+        [](ANeuralNetworksModel* model, ANeuralNetworksOperationType type,
+           uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount,
+           const uint32_t* outputs) { return Value; };
+  }
+
+  template 
+  void IdentifyInputAndOutputsReturns() {
+    nnapi_->ANeuralNetworksModel_identifyInputsAndOutputs =
+        [](ANeuralNetworksModel* model, uint32_t inputCount,
+           const uint32_t* inputs, uint32_t outputCount,
+           const uint32_t* outputs) { return Value; };
+  }
+
+  template 
+  void RelaxComputationFloatReturns() {
+    nnapi_->ANeuralNetworksModel_relaxComputationFloat32toFloat16 =
+        [](ANeuralNetworksModel* model, bool allow) { return Value; };
+  }
+
+  template 
+  void ModelFinishReturns() {
+    nnapi_->ANeuralNetworksModel_finish = [](ANeuralNetworksModel* model) {
+      return Value;
+    };
+  }
+
+  template 
+  void MemoryCreateFromFdReturns() {
+    nnapi_->ANeuralNetworksMemory_createFromFd =
+        [](size_t size, int protect, int fd, size_t offset,
+           ANeuralNetworksMemory** memory) {
+          *memory = reinterpret_cast(2);
+          return Value;
+        };
+  }
+
+  template 
+  void CompilationCreateReturns() {
+    nnapi_->ANeuralNetworksCompilation_create =
+        [](ANeuralNetworksModel* model,
+           ANeuralNetworksCompilation** compilation) {
+          *compilation = reinterpret_cast(3);
+          return Value;
+        };
+  }
+
+  template 
+  void CompilationFinishReturns() {
+    nnapi_->ANeuralNetworksCompilation_finish =
+        [](ANeuralNetworksCompilation* compilation) { return Value; };
+  }
+
+  template 
+  void ExecutionCreateReturns() {
+    nnapi_->ANeuralNetworksExecution_create =
+        [](ANeuralNetworksCompilation* compilation,
+           ANeuralNetworksExecution** execution) {
+          if (compilation == nullptr) return 1;
+          *execution = reinterpret_cast(4);
+          return Value;
+        };
+  }
+  template 
+  void ExecutionSetInputFromMemoryReturns() {
+    nnapi_->ANeuralNetworksExecution_setInputFromMemory =
+        [](ANeuralNetworksExecution* execution, int32_t index,
+           const ANeuralNetworksOperandType* type,
+           const ANeuralNetworksMemory* memory, size_t offset,
+           size_t length) { return Value; };
+  }
+  template 
+  void ExecutionSetOutputFromMemoryReturns() {
+    nnapi_->ANeuralNetworksExecution_setOutputFromMemory =
+        [](ANeuralNetworksExecution* execution, int32_t index,
+           const ANeuralNetworksOperandType* type,
+           const ANeuralNetworksMemory* memory, size_t offset,
+           size_t length) { return Value; };
+  }
+
+  template 
+  void ExecutionComputeReturns() {
+    nnapi_->ANeuralNetworksExecution_compute =
+        [](ANeuralNetworksExecution* execution) { return Value; };
+  }
+
+ protected:
+  explicit NnApiHandler(NnApi* nnapi) : nnapi_(nnapi) { DCHECK(nnapi); }
+
+  NnApi* nnapi_;
+};
+
+// Returns a pointer to an unaltered instance of NNAPI. Is intended
+// to be used by stub methods when wanting to pass-through to original
+// implementation for example:
+//
+// NnApiTestUtility()->StubGetDeviceWith(
+//  [](uint32_t devIndex, ANeuralNetworksDevice** device) -> int {
+//        static int count = 0;
+//        if (count++ < 1) {
+//          NnApiPassthroughInstance()->ANeuralNetworks_getDevice(
+//                devIndex, device);
+//        } else {
+//            return ANEURALNETWORKS_BAD_DATA;
+//        }
+//   });
+const NnApi* NnApiPassthroughInstance();
+
+// Returns an instance of NnApiProxy that can be used to alter
+// the behaviour of the TFLite wide instance of NnApi.
+NnApiHandler* NnApiProxyInstance();
+
+}  // namespace nnapi
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_NNAPI_NNAPI_HANDLER_H_
diff --git a/tensorflow/lite/nnapi/nnapi_handler_test.cc b/tensorflow/lite/nnapi/nnapi_handler_test.cc
new file mode 100644
index 00000000000..aea766ef036
--- /dev/null
+++ b/tensorflow/lite/nnapi/nnapi_handler_test.cc
@@ -0,0 +1,143 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/nnapi/nnapi_handler.h"
+
+#include 
+#include 
+
+#include 
+#include 
+#include "tensorflow/lite/nnapi/nnapi_implementation.h"
+
+namespace tflite {
+namespace nnapi {
+
+using testing::Eq;
+using testing::Ne;
+using testing::NotNull;
+
+void ExpectEquals(const NnApi& left, const NnApi& right);
+
+class NnApiHandlerTest : public ::testing::Test {
+ protected:
+  ~NnApiHandlerTest() override { NnApiHandler::Instance()->Reset(); }
+};
+
+TEST_F(NnApiHandlerTest, ShouldAlterNnApiInstanceBehaviour) {
+  const NnApi* nnapi = NnApiImplementation();
+
+  const auto device_count_stub = [](uint32_t* device_count) -> int {
+    *device_count = 999;
+    return ANEURALNETWORKS_NO_ERROR;
+  };
+
+  NnApiHandler::Instance()->StubGetDeviceCountWith(device_count_stub);
+
+  ASSERT_THAT(nnapi->ANeuralNetworks_getDeviceCount, NotNull());
+
+  uint32_t device_count = 0;
+  nnapi->ANeuralNetworks_getDeviceCount(&device_count);
+  EXPECT_THAT(device_count, Eq(999));
+}
+
+TEST_F(NnApiHandlerTest, ShouldRestoreNnApiToItsOriginalValueWithReset) {
+  NnApi nnapi_orig_copy = *NnApiImplementation();
+
+  auto device_count_override = [](uint32_t* device_count) -> int {
+    *device_count = 777;
+    return ANEURALNETWORKS_NO_ERROR;
+  };
+
+  NnApiHandler::Instance()->StubGetDeviceCountWith(device_count_override);
+
+  EXPECT_THAT(nnapi_orig_copy.ANeuralNetworks_getDeviceCount,
+              Ne(NnApiImplementation()->ANeuralNetworks_getDeviceCount));
+
+  NnApiHandler::Instance()->Reset();
+
+  ExpectEquals(nnapi_orig_copy, *NnApiImplementation());
+}
+
+int (*device_count_ptr)(uint32_t*);
+TEST_F(NnApiHandlerTest, ShouldSupportPassthroughCalls) {
+  const NnApi* nnapi = NnApiImplementation();
+  device_count_ptr = nnapi->ANeuralNetworks_getDeviceCount;
+
+  NnApiHandler::Instance()->StubGetDeviceCountWith(
+      [](uint32_t* device_count) -> int {
+        return NnApiPassthroughInstance()->ANeuralNetworks_getDeviceCount ==
+               device_count_ptr;
+      });
+
+  uint32_t device_count = 0;
+  EXPECT_THAT(nnapi->ANeuralNetworks_getDeviceCount(&device_count), Eq(1));
+}
+
+void ExpectEquals(const NnApi& left, const NnApi& right) {
+#define EXPECT_NNAPI_MEMBER_EQ(name) EXPECT_EQ(left.name, right.name)
+
+  EXPECT_NNAPI_MEMBER_EQ(nnapi_exists);
+  EXPECT_NNAPI_MEMBER_EQ(android_sdk_version);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksMemory_createFromFd);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksMemory_free);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksModel_create);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksModel_free);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksModel_finish);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksModel_addOperand);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksModel_setOperandValue);
+  EXPECT_NNAPI_MEMBER_EQ(
+      ANeuralNetworksModel_setOperandSymmPerChannelQuantParams);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksModel_setOperandValueFromMemory);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksModel_addOperation);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksModel_identifyInputsAndOutputs);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksModel_relaxComputationFloat32toFloat16);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksCompilation_create);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksCompilation_free);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksCompilation_setPreference);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksCompilation_finish);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksExecution_create);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksExecution_free);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksExecution_setInput);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksExecution_setInputFromMemory);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksExecution_setOutput);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksExecution_setOutputFromMemory);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksExecution_startCompute);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksEvent_wait);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksEvent_free);
+  EXPECT_NNAPI_MEMBER_EQ(ASharedMemory_create);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworks_getDeviceCount);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworks_getDevice);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksDevice_getName);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksDevice_getVersion);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksDevice_getFeatureLevel);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksDevice_getType);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksModel_getSupportedOperationsForDevices);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksCompilation_createForDevices);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksCompilation_setCaching);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksExecution_compute);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksExecution_getOutputOperandRank);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksExecution_getOutputOperandDimensions);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksBurst_create);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksBurst_free);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksExecution_burstCompute);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksMemory_createFromAHardwareBuffer);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksExecution_setMeasureTiming);
+  EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksExecution_getDuration);
+
+#undef EXPECT_NNAPI_MEMBER_EQ
+}
+
+}  // namespace nnapi
+}  // namespace tflite

From 0ffee12083b28255284107621dc79fb71e6c1b1e Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Thu, 28 Nov 2019 11:50:47 -0800
Subject: [PATCH 098/279] Split out FunctionLike printing/parsing into
 FunctionImplementation.{h,cpp}

Helper utilies for parsing and printing FunctionLike Ops are only relevant to
the implementation of the Op, not its definition. They depend on
OpImplementation.h and increase the inclusion footprint of FunctionSupport.h,
and do so only to provide some utilities in the "impl" namespace. Move them to
a separate files, similarly to OpDefinition/OpImplementation distinction, and
make only Op implementations use them while keeping headers cleaner. NFC.

PiperOrigin-RevId: 282964556
Change-Id: I265b4f58fafeb5cedf42697b5dca66892fec26d7
---
 third_party/mlir/BUILD                        |   3 +-
 .../include/mlir/IR/FunctionImplementation.h  | 109 ++++++++++++++++++
 .../mlir/include/mlir/IR/FunctionSupport.h    |  73 +-----------
 .../mlir/lib/Dialect/GPU/IR/GPUDialect.cpp    |   1 +
 .../lib/Dialect/LLVMIR/IR/LLVMDialect.cpp     |   1 +
 third_party/mlir/lib/IR/Function.cpp          |   1 +
 ...Support.cpp => FunctionImplementation.cpp} |   6 +-
 7 files changed, 118 insertions(+), 76 deletions(-)
 create mode 100644 third_party/mlir/include/mlir/IR/FunctionImplementation.h
 rename third_party/mlir/lib/IR/{FunctionSupport.cpp => FunctionImplementation.cpp} (99%)

diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD
index 57893543f6f..76aecb2088c 100644
--- a/third_party/mlir/BUILD
+++ b/third_party/mlir/BUILD
@@ -79,7 +79,7 @@ cc_library(
         "lib/IR/Diagnostics.cpp",
         "lib/IR/Dialect.cpp",
         "lib/IR/Function.cpp",
-        "lib/IR/FunctionSupport.cpp",
+        "lib/IR/FunctionImplementation.cpp",
         "lib/IR/IntegerSet.cpp",
         "lib/IR/IntegerSetDetail.h",
         "lib/IR/Location.cpp",
@@ -114,6 +114,7 @@ cc_library(
         "include/mlir/IR/DialectImplementation.h",
         "include/mlir/IR/DialectInterface.h",
         "include/mlir/IR/Function.h",
+        "include/mlir/IR/FunctionImplementation.h",
         "include/mlir/IR/FunctionSupport.h",
         "include/mlir/IR/Identifier.h",
         "include/mlir/IR/IntegerSet.h",
diff --git a/third_party/mlir/include/mlir/IR/FunctionImplementation.h b/third_party/mlir/include/mlir/IR/FunctionImplementation.h
new file mode 100644
index 00000000000..241d5615acf
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/FunctionImplementation.h
@@ -0,0 +1,109 @@
+//===- FunctionImplementation.h - Function-like Op utilities ----*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file provides utility functions for implementing function-like
+// operations, in particular, parsing, printing and verification components
+// common to function-like operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_FUNCTIONIMPLEMENTATION_H_
+#define MLIR_IR_FUNCTIONIMPLEMENTATION_H_
+
+#include "mlir/IR/FunctionSupport.h"
+#include "mlir/IR/OpImplementation.h"
+
+namespace mlir {
+
+namespace impl {
+
+/// A named class for passing around the variadic flag.
+class VariadicFlag {
+public:
+  explicit VariadicFlag(bool variadic) : variadic(variadic) {}
+  bool isVariadic() const { return variadic; }
+
+private:
+  /// Underlying storage.
+  bool variadic;
+};
+
+/// Adds argument and result attributes, provided as `argAttrs` and
+/// `resultAttrs` arguments, to the list of operation attributes in `result`.
+/// Internally, argument and result attributes are stored as dict attributes
+/// with special names given by getResultAttrName, getArgumentAttrName.
+void addArgAndResultAttrs(Builder &builder, OperationState &result,
+                          ArrayRef> argAttrs,
+                          ArrayRef> resultAttrs);
+
+/// Callback type for `parseFunctionLikeOp`, the callback should produce the
+/// type that will be associated with a function-like operation from lists of
+/// function arguments and results, VariadicFlag indicates whether the function
+/// should have variadic arguments; in case of error, it may populate the last
+/// argument with a message.
+using FuncTypeBuilder = llvm::function_ref, ArrayRef, VariadicFlag, std::string &)>;
+
+/// Parses a function signature using `parser`. The `allowVariadic` argument
+/// indicates whether functions with variadic arguments are supported. The
+/// trailing arguments are populated by this function with names, types and
+/// attributes of the arguments and those of the results.
+ParseResult parseFunctionSignature(
+    OpAsmParser &parser, bool allowVariadic,
+    SmallVectorImpl &argNames,
+    SmallVectorImpl &argTypes,
+    SmallVectorImpl> &argAttrs, bool &isVariadic,
+    SmallVectorImpl &resultTypes,
+    SmallVectorImpl> &resultAttrs);
+
+/// Parser implementation for function-like operations.  Uses
+/// `funcTypeBuilder` to construct the custom function type given lists of
+/// input and output types.  If `allowVariadic` is set, the parser will accept
+/// trailing ellipsis in the function signature and indicate to the builder
+/// whether the function is variadic.  If the builder returns a null type,
+/// `result` will not contain the `type` attribute.  The caller can then add a
+/// type, report the error or delegate the reporting to the op's verifier.
+ParseResult parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
+                                bool allowVariadic,
+                                FuncTypeBuilder funcTypeBuilder);
+
+/// Printer implementation for function-like operations.  Accepts lists of
+/// argument and result types to use while printing.
+void printFunctionLikeOp(OpAsmPrinter &p, Operation *op,
+                         ArrayRef argTypes, bool isVariadic,
+                         ArrayRef resultTypes);
+
+/// Prints the signature of the function-like operation `op`.  Assumes `op` has
+/// the FunctionLike trait and passed the verification.
+void printFunctionSignature(OpAsmPrinter &p, Operation *op,
+                            ArrayRef argTypes, bool isVariadic,
+                            ArrayRef resultTypes);
+
+/// Prints the list of function prefixed with the "attributes" keyword. The
+/// attributes with names listed in "elided" as well as those used by the
+/// function-like operation internally are not printed. Nothing is printed
+/// if all attributes are elided. Assumes `op` has the `FunctionLike` trait and
+/// passed the verification.
+void printFunctionAttributes(OpAsmPrinter &p, Operation *op, unsigned numInputs,
+                             unsigned numResults,
+                             ArrayRef elided = {});
+
+} // namespace impl
+
+} // namespace mlir
+
+#endif // MLIR_IR_FUNCTIONIMPLEMENTATION_H_
diff --git a/third_party/mlir/include/mlir/IR/FunctionSupport.h b/third_party/mlir/include/mlir/IR/FunctionSupport.h
index 38e406e8f08..4656c35a9c2 100644
--- a/third_party/mlir/include/mlir/IR/FunctionSupport.h
+++ b/third_party/mlir/include/mlir/IR/FunctionSupport.h
@@ -24,12 +24,12 @@
 #define MLIR_IR_FUNCTIONSUPPORT_H
 
 #include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/OpImplementation.h"
 #include "llvm/ADT/SmallString.h"
 
 namespace mlir {
 
 namespace impl {
+
 /// Return the name of the attribute used for function types.
 inline StringRef getTypeAttrName() { return "type"; }
 
@@ -73,77 +73,6 @@ inline ArrayRef getResultAttrs(Operation *op, unsigned index) {
   return resultDict ? resultDict.getValue() : llvm::None;
 }
 
-/// A named class for passing around the variadic flag.
-class VariadicFlag {
-public:
-  explicit VariadicFlag(bool variadic) : variadic(variadic) {}
-  bool isVariadic() const { return variadic; }
-
-private:
-  /// Underlying storage.
-  bool variadic;
-};
-
-/// Adds argument and result attributes, provided as `argAttrs` and
-/// `resultAttrs` arguments, to the list of operation attributes in `result`.
-/// Internally, argument and result attributes are stored as dict attributes
-/// with special names given by getResultAttrName, getArgumentAttrName.
-void addArgAndResultAttrs(Builder &builder, OperationState &result,
-                          ArrayRef> argAttrs,
-                          ArrayRef> resultAttrs);
-
-/// Callback type for `parseFunctionLikeOp`, the callback should produce the
-/// type that will be associated with a function-like operation from lists of
-/// function arguments and results, VariadicFlag indicates whether the function
-/// should have variadic arguments; in case of error, it may populate the last
-/// argument with a message.
-using FuncTypeBuilder = llvm::function_ref, ArrayRef, VariadicFlag, std::string &)>;
-
-/// Parses a function signature using `parser`. The `allowVariadic` argument
-/// indicates whether functions with variadic arguments are supported. The
-/// trailing arguments are populated by this function with names, types and
-/// attributes of the arguments and those of the results.
-ParseResult parseFunctionSignature(
-    OpAsmParser &parser, bool allowVariadic,
-    SmallVectorImpl &argNames,
-    SmallVectorImpl &argTypes,
-    SmallVectorImpl> &argAttrs, bool &isVariadic,
-    SmallVectorImpl &resultTypes,
-    SmallVectorImpl> &resultAttrs);
-
-/// Parser implementation for function-like operations.  Uses
-/// `funcTypeBuilder` to construct the custom function type given lists of
-/// input and output types.  If `allowVariadic` is set, the parser will accept
-/// trailing ellipsis in the function signature and indicate to the builder
-/// whether the function is variadic.  If the builder returns a null type,
-/// `result` will not contain the `type` attribute.  The caller can then add a
-/// type, report the error or delegate the reporting to the op's verifier.
-ParseResult parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
-                                bool allowVariadic,
-                                FuncTypeBuilder funcTypeBuilder);
-
-/// Printer implementation for function-like operations.  Accepts lists of
-/// argument and result types to use while printing.
-void printFunctionLikeOp(OpAsmPrinter &p, Operation *op,
-                         ArrayRef argTypes, bool isVariadic,
-                         ArrayRef resultTypes);
-
-/// Prints the signature of the function-like operation `op`.  Assumes `op` has
-/// the FunctionLike trait and passed the verification.
-void printFunctionSignature(OpAsmPrinter &p, Operation *op,
-                            ArrayRef argTypes, bool isVariadic,
-                            ArrayRef resultTypes);
-
-/// Prints the list of function prefixed with the "attributes" keyword. The
-/// attributes with names listed in "elided" as well as those used by the
-/// function-like operation internally are not printed. Nothing is printed
-/// if all attributes are elided. Assumes `op` has the `FunctionLike` trait and
-/// passed the verification.
-void printFunctionAttributes(OpAsmPrinter &p, Operation *op, unsigned numInputs,
-                             unsigned numResults,
-                             ArrayRef elided = {});
-
 } // namespace impl
 
 namespace OpTrait {
diff --git a/third_party/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/third_party/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 5fc1cade760..8d84fadae8a 100644
--- a/third_party/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/third_party/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -24,6 +24,7 @@
 #include "mlir/Dialect/StandardOps/Ops.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Function.h"
+#include "mlir/IR/FunctionImplementation.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
diff --git a/third_party/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/third_party/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index ca71db5fd8d..66a9bc0ae9f 100644
--- a/third_party/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/third_party/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/FunctionImplementation.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/StandardTypes.h"
diff --git a/third_party/mlir/lib/IR/Function.cpp b/third_party/mlir/lib/IR/Function.cpp
index 4e103508af0..e5e854260f3 100644
--- a/third_party/mlir/lib/IR/Function.cpp
+++ b/third_party/mlir/lib/IR/Function.cpp
@@ -20,6 +20,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/FunctionImplementation.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/OpImplementation.h"
diff --git a/third_party/mlir/lib/IR/FunctionSupport.cpp b/third_party/mlir/lib/IR/FunctionImplementation.cpp
similarity index 99%
rename from third_party/mlir/lib/IR/FunctionSupport.cpp
rename to third_party/mlir/lib/IR/FunctionImplementation.cpp
index c6f2673ef2a..a1fc21e11ea 100644
--- a/third_party/mlir/lib/IR/FunctionSupport.cpp
+++ b/third_party/mlir/lib/IR/FunctionImplementation.cpp
@@ -1,4 +1,4 @@
-//===- FunctionSupport.cpp - Utility types for function-like ops ----------===//
+//===- FunctionImplementation.cpp - Utilities for function-like ops -------===//
 //
 // Copyright 2019 The MLIR Authors.
 //
@@ -15,9 +15,9 @@
 // limitations under the License.
 // =============================================================================
 
-#include "mlir/IR/FunctionSupport.h"
+#include "mlir/IR/FunctionImplementation.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/FunctionSupport.h"
 #include "mlir/IR/SymbolTable.h"
 
 using namespace mlir;

From 4cc04407f206fd3b340e5b6b8df12f6157dda4c4 Mon Sep 17 00:00:00 2001
From: Brian Zhao 
Date: Thu, 28 Nov 2019 13:05:46 -0800
Subject: [PATCH 099/279] Adding tensorflow/core/lib/monitoring/BUILD. This
 change is part of the refactoring described in Tensorflow Build Improvements
 RFC: https://github.com/tensorflow/community/pull/179

PiperOrigin-RevId: 282970277
Change-Id: I1430d0beea8917bd4fb167f30c740e2e9329b3cc
---
 tensorflow/core/BUILD                         |  33 +--
 tensorflow/core/lib/histogram/BUILD           |   2 +
 tensorflow/core/lib/monitoring/BUILD          | 195 ++++++++++++++++++
 .../core/lib/monitoring/collection_registry.h |   2 +-
 tensorflow/core/lib/monitoring/metric_def.h   |   2 +-
 5 files changed, 218 insertions(+), 16 deletions(-)
 create mode 100644 tensorflow/core/lib/monitoring/BUILD

diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 74859e11a79..29fe1c92932 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -501,12 +501,6 @@ cc_library(
 cc_library(
     name = "lib",
     hdrs = [
-        "lib/monitoring/collected_metrics.h",
-        "lib/monitoring/collection_registry.h",
-        "lib/monitoring/counter.h",
-        "lib/monitoring/gauge.h",
-        "lib/monitoring/metric_def.h",
-        "lib/monitoring/sampler.h",
         ":platform_base_hdrs",
         ":platform_env_hdrs",
         ":platform_file_system_hdrs",
@@ -520,6 +514,7 @@ cc_library(
         "//tensorflow/core/lib/histogram:legacy_lib_histogram_all_headers",
         "//tensorflow/core/lib/io:legacy_lib_io_headers",
         "//tensorflow/core/lib/math:math_util.h",
+        "//tensorflow/core/lib/monitoring:legacy_lib_monitoring_lib_headers",
         "//tensorflow/core/lib/random:legacy_lib_random_headers",
         "//tensorflow/core/lib/strings:legacy_lib_string_headers",
     ],
@@ -1548,6 +1543,8 @@ filegroup(
         "//tensorflow/core/lib/io:legacy_lib_io_all_headers",
         "//tensorflow/core/lib/io:legacy_lib_io_all_srcs",
         "//tensorflow/core/lib/math:math_util.h",
+        "//tensorflow/core/lib/monitoring:legacy_lib_monitoring_all_headers",
+        "//tensorflow/core/lib/monitoring:legacy_lib_monitoring_all_srcs",
         "//tensorflow/core/lib/random:legacy_lib_random_all_headers",
         "//tensorflow/core/lib/random:legacy_lib_random_all_srcs",
         "//tensorflow/core/lib/strings:legacy_lib_strings_all_headers",
@@ -2095,6 +2092,7 @@ LIB_INTERNAL_PRIVATE_HEADERS = [
     "//tensorflow/core/lib/gtl:legacy_lib_gtl_all_headers",
     "//tensorflow/core/lib/histogram:legacy_lib_histogram_all_headers",
     "//tensorflow/core/lib/hash:legacy_lib_hash_all_headers",
+    "//tensorflow/core/lib/monitoring:legacy_lib_monitoring_all_headers",
     "//tensorflow/core/lib/io:legacy_lib_io_all_headers",
     "//tensorflow/core/lib/random:legacy_lib_random_all_headers",
     "//tensorflow/core/lib/strings:legacy_lib_strings_all_headers",
@@ -2116,9 +2114,7 @@ LIB_INTERNAL_PUBLIC_HEADERS = [
     "//tensorflow/core/lib/gtl:legacy_lib_internal_public_gtl_headers",
     "//tensorflow/core/lib/hash:legacy_lib_internal_public_headers",
     "//tensorflow/core/lib/io:legacy_lib_internal_public_headers",
-    "lib/monitoring/mobile_counter.h",
-    "lib/monitoring/mobile_gauge.h",
-    "lib/monitoring/mobile_sampler.h",
+    "//tensorflow/core/lib/monitoring:legacy_lib_monitoring_lib_internal_public_headers",
     "//tensorflow/core/lib/random:legacy_lib_internal_public_random_headers",
     "//tensorflow/core/lib/strings:legacy_lib_internal_public_string_headers",
     "lib/wav/wav_io.h",
@@ -2240,6 +2236,15 @@ cc_library(
         "//tensorflow/core/lib/io:zlib_inputstream",
         "//tensorflow/core/lib/io:zlib_outputbuffer",
         "//tensorflow/core/lib/math:math_util",
+        "//tensorflow/core/lib/monitoring:collected_metrics",
+        "//tensorflow/core/lib/monitoring:collection_registry",
+        "//tensorflow/core/lib/monitoring:counter",
+        "//tensorflow/core/lib/monitoring:gauge",
+        "//tensorflow/core/lib/monitoring:metric_def",
+        "//tensorflow/core/lib/monitoring:mobile_counter",
+        "//tensorflow/core/lib/monitoring:mobile_gauge",
+        "//tensorflow/core/lib/monitoring:mobile_sampler",
+        "//tensorflow/core/lib/monitoring:sampler",
         "//tensorflow/core/lib/random:exact_uniform_int",
         "//tensorflow/core/lib/random:philox",
         "//tensorflow/core/lib/random:philox_random",
@@ -3455,11 +3460,6 @@ tf_cc_tests(
     name = "low_level_library_tests",
     size = "small",
     srcs = [
-        "lib/monitoring/collection_registry_test.cc",
-        "lib/monitoring/counter_test.cc",
-        "lib/monitoring/gauge_test.cc",
-        "lib/monitoring/metric_def_test.cc",
-        "lib/monitoring/sampler_test.cc",
         "lib/wav/wav_io_test.cc",
         "//tensorflow/core/lib/core:legacy_lib_core_all_tests",
         "//tensorflow/core/lib/gtl:legacy_lib_gtl_tests",
@@ -3467,6 +3467,11 @@ tf_cc_tests(
         "//tensorflow/core/lib/histogram:legacy_lib_histogram_all_tests",
         "//tensorflow/core/lib/io:legacy_lib_io_all_tests",
         "//tensorflow/core/lib/math:math_util_test.cc",
+        "//tensorflow/core/lib/monitoring:collection_registry_test.cc",
+        "//tensorflow/core/lib/monitoring:counter_test.cc",
+        "//tensorflow/core/lib/monitoring:gauge_test.cc",
+        "//tensorflow/core/lib/monitoring:metric_def_test.cc",
+        "//tensorflow/core/lib/monitoring:sampler_test.cc",
         "//tensorflow/core/lib/random:legacy_lib_random_tests",
         "//tensorflow/core/lib/strings:legacy_low_level_library_tests",
         "//tensorflow/core/platform:fingerprint_test.cc",
diff --git a/tensorflow/core/lib/histogram/BUILD b/tensorflow/core/lib/histogram/BUILD
index 5eba33b0430..9108a09dd15 100644
--- a/tensorflow/core/lib/histogram/BUILD
+++ b/tensorflow/core/lib/histogram/BUILD
@@ -2,6 +2,8 @@ package(
     default_visibility = [
         # tensorflow/core:lib effectively exposes all targets under tensorflow/core/lib/**
         "//tensorflow/core:__pkg__",
+        # tensorflow/core/lib/monitoring:sampler uses histogram
+        "//tensorflow/core/lib/monitoring:__pkg__",
     ],
     licenses = ["notice"],  # Apache 2.0
 )
diff --git a/tensorflow/core/lib/monitoring/BUILD b/tensorflow/core/lib/monitoring/BUILD
new file mode 100644
index 00000000000..35c59079231
--- /dev/null
+++ b/tensorflow/core/lib/monitoring/BUILD
@@ -0,0 +1,195 @@
+package(
+    default_visibility = [
+        # tensorflow/core:lib effectively exposes all targets under tensorflow/core/lib/**
+        "//tensorflow/core:__pkg__",
+        # tensorflow/core/platform:monitoring depends on this package
+        "//tensorflow/core/platform:__subpackages__",
+    ],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+# Todo(bmzhao): Remaining targets to add are: all tests.
+
+cc_library(
+    name = "collected_metrics",
+    hdrs = ["collected_metrics.h"],
+    deps = [
+        ":metric_def",
+        "//tensorflow/core/framework:summary_proto_cc",
+    ],
+)
+
+cc_library(
+    name = "collection_registry",
+    srcs = ["collection_registry.cc"],
+    hdrs = ["collection_registry.h"],
+    deps = [
+        ":collected_metrics",
+        ":metric_def",
+        "//tensorflow/core/framework:summary_proto_cc",
+        "//tensorflow/core/platform:env",
+        "//tensorflow/core/platform:logging",
+        "//tensorflow/core/platform:macros",
+        "//tensorflow/core/platform:mutex",
+        "//tensorflow/core/platform:stringpiece",
+        "//tensorflow/core/platform:thread_annotations",
+        "//tensorflow/core/platform:types",
+    ],
+)
+
+cc_library(
+    name = "counter",
+    hdrs = ["counter.h"],
+    deps = [
+        ":collection_registry",
+        ":metric_def",
+        ":mobile_counter",
+        "//tensorflow/core/lib/core:status",
+        "//tensorflow/core/platform",
+        "//tensorflow/core/platform:logging",
+        "//tensorflow/core/platform:macros",
+        "//tensorflow/core/platform:mutex",
+        "//tensorflow/core/platform:thread_annotations",
+    ],
+)
+
+cc_library(
+    name = "gauge",
+    hdrs = ["gauge.h"],
+    deps = [
+        ":collection_registry",
+        ":metric_def",
+        ":mobile_gauge",
+        "//tensorflow/core/lib/core:status",
+        "//tensorflow/core/platform",
+        "//tensorflow/core/platform:macros",
+        "//tensorflow/core/platform:mutex",
+        "//tensorflow/core/platform:thread_annotations",
+        "//tensorflow/core/platform:types",
+    ],
+)
+
+cc_library(
+    name = "metric_def",
+    hdrs = ["metric_def.h"],
+    deps = [
+        "//tensorflow/core/framework:summary_proto_cc",
+        "//tensorflow/core/platform:stringpiece",
+        "//tensorflow/core/platform:types",
+    ],
+)
+
+cc_library(
+    name = "mobile_counter",
+    hdrs = ["mobile_counter.h"],
+    deps = [
+        "//tensorflow/core/lib/core:status",
+        "//tensorflow/core/platform:macros",
+        "//tensorflow/core/platform:types",
+    ],
+)
+
+cc_library(
+    name = "mobile_gauge",
+    hdrs = ["mobile_gauge.h"],
+    deps = [
+        "//tensorflow/core/lib/core:status",
+        "//tensorflow/core/platform:macros",
+        "//tensorflow/core/platform:types",
+    ],
+)
+
+cc_library(
+    name = "mobile_sampler",
+    hdrs = ["mobile_sampler.h"],
+    deps = [
+        ":metric_def",
+        "//tensorflow/core/framework:summary_proto_cc",
+        "//tensorflow/core/lib/core:status",
+        "//tensorflow/core/platform:macros",
+        "//tensorflow/core/platform:types",
+    ],
+)
+
+cc_library(
+    name = "sampler",
+    srcs = ["sampler.cc"],
+    hdrs = ["sampler.h"],
+    deps = [
+        ":collection_registry",
+        ":metric_def",
+        ":mobile_sampler",
+        "//tensorflow/core/framework:summary_proto_cc",
+        "//tensorflow/core/lib/core:status",
+        "//tensorflow/core/lib/histogram",
+        "//tensorflow/core/platform",
+        "//tensorflow/core/platform:macros",
+        "//tensorflow/core/platform:mutex",
+        "//tensorflow/core/platform:thread_annotations",
+    ],
+)
+
+filegroup(
+    name = "legacy_lib_monitoring_lib_headers",
+    srcs = [
+        "collected_metrics.h",
+        "collection_registry.h",
+        "counter.h",
+        "gauge.h",
+        "metric_def.h",
+        "sampler.h",
+    ],
+    visibility = ["//tensorflow/core:__pkg__"],
+)
+
+filegroup(
+    name = "legacy_lib_monitoring_lib_internal_public_headers",
+    srcs = [
+        "mobile_counter.h",
+        "mobile_gauge.h",
+        "mobile_sampler.h",
+    ],
+    visibility = ["//tensorflow/core:__pkg__"],
+)
+
+filegroup(
+    name = "legacy_lib_monitoring_all_headers",
+    srcs = [
+        "collected_metrics.h",
+        "collection_registry.h",
+        "counter.h",
+        "gauge.h",
+        "metric_def.h",
+        "mobile_counter.h",
+        "mobile_gauge.h",
+        "mobile_sampler.h",
+        "sampler.h",
+    ],
+    visibility = ["//tensorflow/core:__pkg__"],
+)
+
+filegroup(
+    name = "legacy_lib_monitoring_all_srcs",
+    srcs = [
+        "collection_registry.cc",
+        "sampler.cc",
+    ],
+    visibility = ["//tensorflow/core:__pkg__"],
+)
+
+# Note(bmzhao): Ideally we would use a filegroup to represent these tests instead.
+# However, that causes tf_cc_tests to link all of these tests into a single object
+# file. This breaks collection_registry_test, because sample_test.cc has static variables
+# that instantiate metrics with the same names that collection_registry_test tries
+# to create ("/tensorflow/test/sampler_with_labels" and
+# "/tensorflow/test/sampler_without_labels").
+exports_files(
+    [
+        "collection_registry_test.cc",
+        "counter_test.cc",
+        "gauge_test.cc",
+        "metric_def_test.cc",
+        "sampler_test.cc",
+    ],
+    visibility = ["//tensorflow/core:__pkg__"],
+)
diff --git a/tensorflow/core/lib/monitoring/collection_registry.h b/tensorflow/core/lib/monitoring/collection_registry.h
index 9e4e1989dd8..b3db7079d12 100644
--- a/tensorflow/core/lib/monitoring/collection_registry.h
+++ b/tensorflow/core/lib/monitoring/collection_registry.h
@@ -20,13 +20,13 @@ limitations under the License.
 #include 
 
 #include "tensorflow/core/framework/summary.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/lib/monitoring/collected_metrics.h"
 #include "tensorflow/core/lib/monitoring/metric_def.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/stringpiece.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/platform/types.h"
 
diff --git a/tensorflow/core/lib/monitoring/metric_def.h b/tensorflow/core/lib/monitoring/metric_def.h
index bc4365e439c..84b915f360c 100644
--- a/tensorflow/core/lib/monitoring/metric_def.h
+++ b/tensorflow/core/lib/monitoring/metric_def.h
@@ -20,7 +20,7 @@ limitations under the License.
 #include 
 
 #include "tensorflow/core/framework/summary.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/stringpiece.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {

From b543321f27623ebac2cc34b24a0c1d17ca1ee560 Mon Sep 17 00:00:00 2001
From: Denis Khalikov 
Date: Thu, 28 Nov 2019 13:27:26 -0800
Subject: [PATCH 100/279] [spirv] Check that operand of
 `spirv::CompositeExtractOp` is constant while folding.

Closes #281

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/281 from denis0x0D:sandbox/composite_ex_fold d02d73658bd1b9eaa515eb4e0aee34bc41d4252b
PiperOrigin-RevId: 282971563
Change-Id: I3089600173dec453efbf9134f8d526cf0d58f6a5
---
 third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index e82420022ea..e8896fac526 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -336,6 +336,9 @@ static void printVariableDecorations(Operation *op, OpAsmPrinter &printer,
 // `indices`. Returns a null Attribute if error happens.
 static Attribute extractCompositeElement(Attribute composite,
                                          ArrayRef indices) {
+  // Check that given composite is a constant.
+  if (!composite)
+    return {};
   // Return composite itself if we reach the end of the index chain.
   if (indices.empty())
     return composite;

From 4596b3bc1ea85ea1bfdb6c0779b90eaab65ae252 Mon Sep 17 00:00:00 2001
From: Gunhan Gulsoy 
Date: Thu, 28 Nov 2019 23:17:10 -0800
Subject: [PATCH 101/279] Remove dependence on the simple redirection header
 platform/cuda.h

This change will help make cuda code not depend on core/platform, paving the way to make core/platform cuda-agnostic.

PiperOrigin-RevId: 283011526
Change-Id: I4a4042ff8c1fa647405e28fa08115b59b3664c17
---
 tensorflow/core/common_runtime/gpu/gpu_device.cc                | 2 +-
 tensorflow/core/kernels/check_numerics_op.cc                    | 2 +-
 tensorflow/core/kernels/crop_and_resize_op.cc                   | 2 +-
 tensorflow/core/kernels/cuda_solvers.cc                         | 2 +-
 tensorflow/core/kernels/segment_reduction_ops_impl.h            | 2 +-
 tensorflow/core/kernels/sparse/dense_to_csr_sparse_matrix_op.cc | 2 +-
 .../kernels/sparse/sparse_tensor_to_csr_sparse_matrix_op.cc     | 2 +-
 tensorflow/core/kernels/where_op.cc                             | 2 +-
 tensorflow/core/nccl/nccl_manager.cc                            | 2 +-
 9 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 0e230a5d2bd..2287bf889ab 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -61,7 +61,7 @@ limitations under the License.
 #include "tensorflow/core/lib/strings/strcat.h"
 #if GOOGLE_CUDA
 #include "third_party/gpus/cudnn/cudnn.h"
-#include "tensorflow/core/platform/cuda.h"
+#include "tensorflow/stream_executor/cuda/cuda_activation.h"
 #elif TENSORFLOW_USE_ROCM
 #include "tensorflow/core/platform/rocm.h"
 #endif
diff --git a/tensorflow/core/kernels/check_numerics_op.cc b/tensorflow/core/kernels/check_numerics_op.cc
index 63c52a11624..17cf5c37ae9 100644
--- a/tensorflow/core/kernels/check_numerics_op.cc
+++ b/tensorflow/core/kernels/check_numerics_op.cc
@@ -31,7 +31,7 @@ limitations under the License.
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #if GOOGLE_CUDA
-#include "tensorflow/core/platform/cuda.h"
+#include "tensorflow/stream_executor/cuda/cuda_activation.h"
 #elif TENSORFLOW_USE_ROCM
 #include "tensorflow/core/platform/rocm.h"
 #endif
diff --git a/tensorflow/core/kernels/crop_and_resize_op.cc b/tensorflow/core/kernels/crop_and_resize_op.cc
index 892fadd51dd..5223501997e 100644
--- a/tensorflow/core/kernels/crop_and_resize_op.cc
+++ b/tensorflow/core/kernels/crop_and_resize_op.cc
@@ -40,7 +40,7 @@ limitations under the License.
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #if GOOGLE_CUDA
-#include "tensorflow/core/platform/cuda.h"
+#include "tensorflow/stream_executor/cuda/cuda_activation.h"
 using stream_executor::cuda::ScopedActivateExecutorContext;
 #elif TENSORFLOW_USE_ROCM
 #include "tensorflow/core/platform/rocm.h"
diff --git a/tensorflow/core/kernels/cuda_solvers.cc b/tensorflow/core/kernels/cuda_solvers.cc
index 9abf5439571..1c569204265 100644
--- a/tensorflow/core/kernels/cuda_solvers.cc
+++ b/tensorflow/core/kernels/cuda_solvers.cc
@@ -30,10 +30,10 @@
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
-#include "tensorflow/core/platform/cuda.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/stream_executor.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/cuda/cuda_activation.h"
 
 // The CUDA cublas_api.h API contains const-correctness errors. Instead of
 // casting away constness on our data, we instead reinterpret the CuBLAS
diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl.h b/tensorflow/core/kernels/segment_reduction_ops_impl.h
index 5aa05faab97..a472655d3e0 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_impl.h
+++ b/tensorflow/core/kernels/segment_reduction_ops_impl.h
@@ -45,7 +45,7 @@ limitations under the License.
 
 #if GOOGLE_CUDA
 #include "tensorflow/core/kernels/cuda_solvers.h"
-#include "tensorflow/core/platform/cuda.h"
+#include "tensorflow/stream_executor/cuda/cuda_activation.h"
 
 using stream_executor::cuda::ScopedActivateExecutorContext;
 #elif TENSORFLOW_USE_ROCM
diff --git a/tensorflow/core/kernels/sparse/dense_to_csr_sparse_matrix_op.cc b/tensorflow/core/kernels/sparse/dense_to_csr_sparse_matrix_op.cc
index 94cbae3185f..6e0397c8d27 100644
--- a/tensorflow/core/kernels/sparse/dense_to_csr_sparse_matrix_op.cc
+++ b/tensorflow/core/kernels/sparse/dense_to_csr_sparse_matrix_op.cc
@@ -36,7 +36,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
 #include "tensorflow/core/kernels/cuda_solvers.h"
 #include "tensorflow/core/kernels/cuda_sparse.h"
-#include "tensorflow/core/platform/cuda.h"
+#include "tensorflow/stream_executor/cuda/cuda_activation.h"
 
 using ::perftools::gputools::cuda::ScopedActivateExecutorContext;
 #endif
diff --git a/tensorflow/core/kernels/sparse/sparse_tensor_to_csr_sparse_matrix_op.cc b/tensorflow/core/kernels/sparse/sparse_tensor_to_csr_sparse_matrix_op.cc
index f791b6c5105..3ecebfe0ac7 100644
--- a/tensorflow/core/kernels/sparse/sparse_tensor_to_csr_sparse_matrix_op.cc
+++ b/tensorflow/core/kernels/sparse/sparse_tensor_to_csr_sparse_matrix_op.cc
@@ -34,7 +34,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
 #include "tensorflow/core/kernels/cuda_solvers.h"
 #include "tensorflow/core/kernels/cuda_sparse.h"
-#include "tensorflow/core/platform/cuda.h"
+#include "tensorflow/stream_executor/cuda/cuda_activation.h"
 
 using ::perftools::gputools::cuda::ScopedActivateExecutorContext;
 #endif
diff --git a/tensorflow/core/kernels/where_op.cc b/tensorflow/core/kernels/where_op.cc
index 2030512bfd2..318894bfce4 100644
--- a/tensorflow/core/kernels/where_op.cc
+++ b/tensorflow/core/kernels/where_op.cc
@@ -41,7 +41,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
 #include "tensorflow/core/kernels/cuda_solvers.h"
 #if GOOGLE_CUDA
-#include "tensorflow/core/platform/cuda.h"
+#include "tensorflow/stream_executor/cuda/cuda_activation.h"
 using stream_executor::cuda::ScopedActivateExecutorContext;
 #elif TENSORFLOW_USE_ROCM
 #include "tensorflow/core/platform/rocm.h"
diff --git a/tensorflow/core/nccl/nccl_manager.cc b/tensorflow/core/nccl/nccl_manager.cc
index 60dea1a215c..aadd2a00f3c 100644
--- a/tensorflow/core/nccl/nccl_manager.cc
+++ b/tensorflow/core/nccl/nccl_manager.cc
@@ -23,7 +23,7 @@ limitations under the License.
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
 #if GOOGLE_CUDA
-#include "tensorflow/core/platform/cuda.h"
+#include "tensorflow/stream_executor/cuda/cuda_activation.h"
 #elif TENSORFLOW_USE_ROCM
 #include "tensorflow/core/platform/rocm.h"
 #endif

From d85c2b19507b23bfc52bbc51cfc60e7eb75345bf Mon Sep 17 00:00:00 2001
From: Gaurav Jain 
Date: Fri, 29 Nov 2019 00:30:42 -0800
Subject: [PATCH 102/279] Make sync local handles avoid lock acquisition

There is no need to notify with sync handles, thus we can save on lock
acquisition.

PiperOrigin-RevId: 283017334
Change-Id: Ibee41ec76999881aeb1bf92879d7dc68a7a2e478
---
 .../core/common_runtime/eager/execute.cc      | 19 +++++-----
 .../common_runtime/eager/tensor_handle.cc     | 35 +++++++++++++------
 .../core/common_runtime/eager/tensor_handle.h |  9 ++---
 .../eager/tensor_handle_data.cc               | 28 +++++++--------
 .../common_runtime/eager/tensor_handle_data.h | 13 ++++---
 .../eager/tensor_handle_test.cc               |  6 ++--
 6 files changed, 62 insertions(+), 48 deletions(-)

diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 6ab90a0b940..6895a79f767 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -642,8 +642,10 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
     graph_collector = ctx->GetGraphCollector();
   }
 
+  const bool async = executor.Async();
   for (int i = 0; i < num_outputs; ++i) {
-    TF_RETURN_IF_ERROR(TensorHandle::CreateAsyncLocalHandle(
+    TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
+        async,
         /* d= */ ctx->CanonicalDevice(kernel->OutputDevice(i)),
         /* op_device= */ kernel->device(),
         /* resource_device= */ kernel->OutputResourceDevice(i),
@@ -651,7 +653,7 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
   }
 
   Status s;
-  if (executor.Async()) {
+  if (async) {
     auto node = absl::make_unique(
         ctx, op->Inputs(), op->remote_func_params(), std::move(kernel),
         graph_collector, output_dtypes, op->GetCancellationManager(),
@@ -1100,9 +1102,9 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
                               EagerExecutor* executor, Device* dstd,
                               TensorHandle** result) {
   TF_RETURN_IF_ERROR(executor->status());
-  TF_RETURN_IF_ERROR(TensorHandle::CreateAsyncLocalHandle(
-      ctx->CanonicalDevice(dstd), dstd, h->resource_device(), h->dtype, ctx,
-      result));
+  TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
+      true, ctx->CanonicalDevice(dstd), dstd, h->resource_device(), h->dtype,
+      ctx, result));
 
   // Note that `h` may not be currently ready. However execution order will
   // make sure that `h` is ready before the copy is actually done.
@@ -1150,10 +1152,9 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
     }
     uint64 recv_op_id = 0;
     if (recver_is_local) {
-      TF_RETURN_IF_ERROR(TensorHandle::CreateAsyncLocalHandle(
-          /* d= */ device,
-          /* op_device= */ device, /*resource_device=*/nullptr, h->dtype, ctx,
-          result));
+      TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
+          true, /* d= */ device, /* op_device= */ device,
+          /*resource_device=*/nullptr, h->dtype, ctx, result));
     } else {
       uint64 context_id = ctx->GetContextId();
       string remote_task;
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc
index a40686c457f..9bda0512b3d 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc
@@ -131,6 +131,7 @@ TensorHandle::TensorHandle(std::unique_ptr t,
 #endif
       ctx_(ctx),
       is_remote_(false),
+      is_async_(false),
       tensor_handle_data_(std::move(t)) {
   DVLOG(3) << "Creating Local TensorHandle: " << this << " device: " << device_;
   // Notify immediately since this handle is already ready.
@@ -150,6 +151,7 @@ TensorHandle::TensorHandle(std::unique_ptr t,
 #endif
       ctx_(ctx),
       is_remote_(false),
+      is_async_(false),
       handle_dtypes_and_shapes_(resource_handle.dtypes_and_shapes()),
       tensor_handle_data_(std::move(t)) {
   DVLOG(3) << "Creating Local TensorHandle: " << this << " device: " << device_;
@@ -157,18 +159,19 @@ TensorHandle::TensorHandle(std::unique_ptr t,
   is_ready_notification_.Notify();
 }
 
-Status TensorHandle::CreateAsyncLocalHandle(Device* d, Device* op_device,
+Status TensorHandle::CreateEmptyLocalHandle(bool async, Device* d,
+                                            Device* op_device,
                                             Device* resource_device,
                                             DataType dtype, EagerContext* ctx,
                                             TensorHandle** h) {
-  *h = new TensorHandle(absl::make_unique(), d,
-                        op_device, resource_device, dtype, ctx);
+  *h = new TensorHandle(absl::make_unique(), async,
+                        d, op_device, resource_device, dtype, ctx);
 
   return Status::OK();
 }
 
-TensorHandle::TensorHandle(std::unique_ptr t,
-                           Device* d, Device* op_device,
+TensorHandle::TensorHandle(std::unique_ptr t,
+                           bool async, Device* d, Device* op_device,
                            Device* resource_device, DataType dtype,
                            EagerContext* ctx)
     : dtype(dtype),
@@ -181,9 +184,13 @@ TensorHandle::TensorHandle(std::unique_ptr t,
 #endif
       ctx_(ctx),
       is_remote_(false),
+      is_async_(async),
       tensor_handle_data_(std::move(t)) {
   DVLOG(3) << "Creating Async Local TensorHandle: " << this
            << " device: " << device_;
+  if (!async) {
+    is_ready_notification_.Notify();
+  }
 }
 
 #if !defined(IS_MOBILE_PLATFORM)
@@ -219,6 +226,7 @@ TensorHandle::TensorHandle(std::unique_ptr t,
       remote_output_num_(t->output_num()),
       ctx_(ctx),
       is_remote_(true),
+      is_async_(false),
       tensor_handle_data_(std::move(t)) {
   DVLOG(3) << "Creating Remote TensorHandle: " << this
            << " device: " << device_;
@@ -255,6 +263,7 @@ TensorHandle::TensorHandle(std::unique_ptr t,
       remote_context_id_(t->context_id()),
       ctx_(ctx),
       is_remote_(true),
+      is_async_(true),
       tensor_handle_data_(std::move(t)) {
   DVLOG(3) << "Creating Unshaped Remote TensorHandle: " << this
            << " device: " << device_;
@@ -262,7 +271,7 @@ TensorHandle::TensorHandle(std::unique_ptr t,
 #endif
 
 bool TensorHandle::IsReady() {
-  return is_ready_notification_.HasBeenNotified();
+  return !is_async_ || is_ready_notification_.HasBeenNotified();
 }
 
 Status TensorHandle::WaitReady(const char* caller) {
@@ -547,7 +556,7 @@ Status TensorHandle::SetRemoteShape(const TensorShape& shape,
 
 Status TensorHandle::SetTensor(tensorflow::Tensor&& tensor) {
   DCHECK(!is_remote_) << "SetTensor is not called on remote handles.";
-  DCHECK(!is_ready_notification_.HasBeenNotified())
+  DCHECK(!is_async_ || !is_ready_notification_.HasBeenNotified())
       << "SetTensor is only called on non-ready handles.";
 
   DVLOG(3) << "SetTensor on TensorHandle: " << this;
@@ -557,19 +566,23 @@ Status TensorHandle::SetTensor(tensorflow::Tensor&& tensor) {
     handle_dtypes_and_shapes_ = resource_handle.dtypes_and_shapes();
   }
   tensor_handle_data_ = absl::make_unique(tensor);
-  is_poisoned_ = Status::OK();
-  is_ready_notification_.Notify();
+  if (is_async_) {
+    is_poisoned_ = Status::OK();
+    is_ready_notification_.Notify();
+  }
   return Status::OK();
 }
 
 void TensorHandle::Poison(Status status) {
-  DCHECK(!is_ready_notification_.HasBeenNotified())
+  DCHECK(!is_async_ || !is_ready_notification_.HasBeenNotified())
       << "Poison(status) can only be called on non-ready handle: " << this;
 
   DVLOG(3) << "Poison on TensorHandle: " << this;
 
   is_poisoned_ = status;
-  is_ready_notification_.Notify();
+  if (is_async_ || is_remote_) {
+    is_ready_notification_.Notify();
+  }
 }
 
 Status TensorHandle::CopyToDevice(EagerContext* ctx, tensorflow::Device* dstd,
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h
index 7372885ed74..f61d3d27951 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.h
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.h
@@ -67,9 +67,9 @@ class TensorHandle : public core::RefCounted {
   TensorHandle(std::unique_ptr t,
                const ResourceHandle& resource_handle, Device* d,
                Device* op_device, EagerContext* ctx);
-  TensorHandle(std::unique_ptr t, Device* d,
-               Device* op_device, Device* resource_device, DataType dtype,
-               EagerContext* ctx);
+  TensorHandle(std::unique_ptr t, bool async,
+               Device* d, Device* op_device, Device* resource_device,
+               DataType dtype, EagerContext* ctx);
 
 #if !defined(IS_MOBILE_PLATFORM)
   TensorHandle(std::unique_ptr t, DataType dtype,
@@ -87,7 +87,7 @@ class TensorHandle : public core::RefCounted {
   static Status CreateLocalHandle(const class Tensor& t, Device* d,
                                   Device* op_device, EagerContext* ctx,
                                   TensorHandle** h);
-  static Status CreateAsyncLocalHandle(Device* d, Device* op_device,
+  static Status CreateEmptyLocalHandle(bool async, Device* d, Device* op_device,
                                        Device* resource_device, DataType dtype,
                                        EagerContext* ctx, TensorHandle** h);
 #if !defined(IS_MOBILE_PLATFORM)
@@ -271,6 +271,7 @@ class TensorHandle : public core::RefCounted {
   // WaitReady() has returned. At that point, is_poisoned_ is immutable.
   Status is_poisoned_;
   const bool is_remote_;
+  const bool is_async_;
 
   // If this TensorHandle 1) is a local tensor, and 2) is a resource handle or
   // refers to a remote resource handle, we store data types and shapes for
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle_data.cc b/tensorflow/core/common_runtime/eager/tensor_handle_data.cc
index 4fb44269584..d718e39687f 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle_data.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle_data.cc
@@ -58,44 +58,44 @@ Status LocalTensorHandleData::NumElements(int64* num_elements) const {
   return Status::OK();
 }
 
-Status AsyncLocalTensorHandleData::Tensor(const tensorflow::Tensor** t) const {
+Status EmptyLocalTensorHandleData::Tensor(const tensorflow::Tensor** t) const {
   return errors::Unavailable(
-      "Unable to get a tensor for an async handle. "
+      "Unable to get a tensor for an empty handle. "
       "Please wait until it is ready");
 }
 
-Status AsyncLocalTensorHandleData::TensorValue(tensorflow::TensorValue* t) {
+Status EmptyLocalTensorHandleData::TensorValue(tensorflow::TensorValue* t) {
   return errors::Unavailable(
-      "Unable to get a tensor for an async handle. "
+      "Unable to get a tensor for an empty handle. "
       "Please wait until it is ready");
 }
 
-Status AsyncLocalTensorHandleData::Shape(TensorShape* shape) const {
+Status EmptyLocalTensorHandleData::Shape(TensorShape* shape) const {
   return errors::Unavailable(
-      "Unable to get shape information for an async handle. "
+      "Unable to get shape information for an empty handle. "
       "Please wait until it is ready");
 }
 
-Status AsyncLocalTensorHandleData::NumDims(int* num_dims) const {
+Status EmptyLocalTensorHandleData::NumDims(int* num_dims) const {
   return errors::Unavailable(
-      "Unable to get shape information for an async handle. "
+      "Unable to get shape information for an empty handle. "
       "Please wait until it is ready");
 }
 
-Status AsyncLocalTensorHandleData::Dim(int dim_index, int64* dim) const {
+Status EmptyLocalTensorHandleData::Dim(int dim_index, int64* dim) const {
   return errors::Unavailable(
-      "Unable to get shape information for an async handle. "
+      "Unable to get shape information for an empty handle. "
       "Please wait until it is ready");
 }
 
-Status AsyncLocalTensorHandleData::NumElements(int64* num_elements) const {
+Status EmptyLocalTensorHandleData::NumElements(int64* num_elements) const {
   return errors::Unavailable(
-      "Unable to get shape information for an async handle. "
+      "Unable to get shape information for an empty handle. "
       "Please wait until it is ready");
 }
 
-string AsyncLocalTensorHandleData::DebugString() const {
-  return "AsyncLocalTensorHandleData";
+string EmptyLocalTensorHandleData::DebugString() const {
+  return "EmptyLocalTensorHandleData";
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle_data.h b/tensorflow/core/common_runtime/eager/tensor_handle_data.h
index c9be6592426..e50200277f1 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle_data.h
+++ b/tensorflow/core/common_runtime/eager/tensor_handle_data.h
@@ -58,15 +58,14 @@ class LocalTensorHandleData : public TensorHandleData {
   tensorflow::Tensor tensor_;
 };
 
-// Async Local Tensor Handle: A non-ready local tensor handle used in async
-// eager execution. Once the execution is complete this is replaced by a local
-// tensor handle.
-class AsyncLocalTensorHandleData : public TensorHandleData {
+// Empty Local Tensor Handle: Once the execution is complete this is replaced by
+// a local tensor handle.
+class EmptyLocalTensorHandleData : public TensorHandleData {
  public:
-  AsyncLocalTensorHandleData() {}
-  ~AsyncLocalTensorHandleData() override {}
+  EmptyLocalTensorHandleData() {}
+  ~EmptyLocalTensorHandleData() override {}
 
-  // Async tensor handles are not ready and hence cannot satisfy any of these
+  // Empty tensor handles are not ready and hence cannot satisfy any of these
   // requests.
   Status Tensor(const tensorflow::Tensor** t) const override;
   Status TensorValue(tensorflow::TensorValue* t) override;
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle_test.cc b/tensorflow/core/common_runtime/eager/tensor_handle_test.cc
index 05801d6e564..d8217e85315 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle_test.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle_test.cc
@@ -33,9 +33,9 @@ TEST(TensorHandle_ShapeTest, AsyncShape) {
   TensorHandle* sync_th;
   EXPECT_TRUE(TensorHandle::CreateLocalHandle(t, &sync_th).ok());
   TensorHandle* async_th;
-  EXPECT_TRUE(TensorHandle::CreateAsyncLocalHandle(nullptr, nullptr, nullptr,
-                                                   DataType::DT_UINT16, nullptr,
-                                                   &async_th)
+  EXPECT_TRUE(TensorHandle::CreateEmptyLocalHandle(true, nullptr, nullptr,
+                                                   nullptr, DataType::DT_UINT16,
+                                                   nullptr, &async_th)
                   .ok());
 
   EXPECT_TRUE(async_th->CopyInferenceShape(sync_th).ok());

From d8127298975c52ee6f49df86d72295e1240df661 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Fri, 29 Nov 2019 01:02:56 -0800
Subject: [PATCH 103/279] compat: Update forward compatibility horizon to
 2019-11-29

PiperOrigin-RevId: 283020698
Change-Id: Ic9a76643a9d44884c084806423d5c4b21f093246
---
 tensorflow/python/compat/compat.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 1f75de96001..fdf52af0ee5 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -31,7 +31,7 @@ from tensorflow.python.util.tf_export import tf_export
 # This value changes every day with an automatic CL. It can be modified in code
 # via `forward_compatibility_horizon()` or with the environment variable
 # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2019, 11, 28)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2019, 11, 29)
 _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
 _FORWARD_COMPATIBILITY_DATE_NUMBER = None
 

From 99cb9fd68d1b8c6f0857c0775c366234537bacf5 Mon Sep 17 00:00:00 2001
From: Chao Mei 
Date: Fri, 29 Nov 2019 01:34:16 -0800
Subject: [PATCH 104/279] Correct the exact meaning of memory stats collected
 from the underlying system, and only output max RSS to avoid confusion.

PiperOrigin-RevId: 283023937
Change-Id: I00b76b6c0e8dc059c3cc22a92f76822d47c8e130
---
 tensorflow/lite/profiling/memory_info.cc      | 16 +++++++--------
 tensorflow/lite/profiling/memory_info.h       | 20 ++++++++++++++-----
 tensorflow/lite/profiling/memory_info_test.cc | 11 +++++++---
 .../lite/tools/benchmark/benchmark_model.cc   |  9 +++++++--
 4 files changed, 37 insertions(+), 19 deletions(-)

diff --git a/tensorflow/lite/profiling/memory_info.cc b/tensorflow/lite/profiling/memory_info.cc
index 9e658f65408..39f94f0d250 100644
--- a/tensorflow/lite/profiling/memory_info.cc
+++ b/tensorflow/lite/profiling/memory_info.cc
@@ -35,20 +35,18 @@ MemoryUsage GetMemoryUsage() {
     result.max_rss_kb = res.ru_maxrss;
   }
   const auto mem = mallinfo();
-  result.total_allocated_bytes = mem.uordblks;
+  result.total_allocated_bytes = mem.arena;
+  result.in_use_allocated_bytes = mem.uordblks;
 #endif
   return result;
 }
 
-void MemoryUsage::SummaryToStream(std::ostream* stream) const {
-  *stream << "memory usage: max resident set size = " << max_rss_kb / 1024.0
+void MemoryUsage::AllStatsToStream(std::ostream* stream) const {
+  *stream << "max resident set size = " << max_rss_kb / 1024.0
           << " MB, total malloc-ed size = "
-          << total_allocated_bytes / 1024.0 / 1024.0 << " MB";
-}
-
-void MemoryUsage::ShortSummaryToStream(std::ostream* stream) const {
-  *stream << "max_rss_mb=" << max_rss_kb / 1024.0
-          << " total_malloced_mb=" << total_allocated_bytes / 1024.0 / 1024.0;
+          << total_allocated_bytes / 1024.0 / 1024.0
+          << " MB, in-use allocated/mmapped size = "
+          << in_use_allocated_bytes / 1024.0 / 1024.0 << " MB";
 }
 
 }  // namespace memory
diff --git a/tensorflow/lite/profiling/memory_info.h b/tensorflow/lite/profiling/memory_info.h
index 370ca3d8ebf..b5bc3a07cf7 100644
--- a/tensorflow/lite/profiling/memory_info.h
+++ b/tensorflow/lite/profiling/memory_info.h
@@ -26,21 +26,30 @@ struct MemoryUsage {
   static const int kValueNotSet;
 
   MemoryUsage()
-      : max_rss_kb(kValueNotSet), total_allocated_bytes(kValueNotSet) {}
+      : max_rss_kb(kValueNotSet),
+        total_allocated_bytes(kValueNotSet),
+        in_use_allocated_bytes(kValueNotSet) {}
 
   // The maximum memory size (in kilobytes) occupied by an OS process that is
   // held in main memory (RAM). Such memory usage information is generally
   // referred as resident set size (rss). This is an alias to rusage::ru_maxrss.
   int64_t max_rss_kb;
 
-  // Total allocated space in bytes. This is an alias to mallinfo::uordblks.
+  // Total non-mmapped space allocated from system in bytes. This is an alias to
+  // mallinfo::arena.
   int total_allocated_bytes;
 
+  // Total allocated (including mmapped) bytes that's in use (i.e. excluding
+  // those are freed). This is an alias to mallinfo::uordblks.
+  int in_use_allocated_bytes;
+
   MemoryUsage operator+(MemoryUsage const& obj) const {
     MemoryUsage res;
     res.max_rss_kb = max_rss_kb + obj.max_rss_kb;
     res.total_allocated_bytes =
         total_allocated_bytes + obj.total_allocated_bytes;
+    res.in_use_allocated_bytes =
+        in_use_allocated_bytes + obj.in_use_allocated_bytes;
     return res;
   }
 
@@ -49,15 +58,16 @@ struct MemoryUsage {
     res.max_rss_kb = max_rss_kb - obj.max_rss_kb;
     res.total_allocated_bytes =
         total_allocated_bytes - obj.total_allocated_bytes;
+    res.in_use_allocated_bytes =
+        in_use_allocated_bytes - obj.in_use_allocated_bytes;
     return res;
   }
 
-  void SummaryToStream(std::ostream* stream) const;
-  void ShortSummaryToStream(std::ostream* stream) const;
+  void AllStatsToStream(std::ostream* stream) const;
 
   friend std::ostream& operator<<(std::ostream& stream,
                                   const MemoryUsage& obj) {
-    obj.SummaryToStream(&stream);
+    obj.AllStatsToStream(&stream);
     return stream;
   }
 };
diff --git a/tensorflow/lite/profiling/memory_info_test.cc b/tensorflow/lite/profiling/memory_info_test.cc
index de595a2b2f1..5a359134160 100644
--- a/tensorflow/lite/profiling/memory_info_test.cc
+++ b/tensorflow/lite/profiling/memory_info_test.cc
@@ -25,23 +25,28 @@ TEST(MemoryUsage, AddAndSub) {
   MemoryUsage mem1, mem2;
   mem1.max_rss_kb = 5;
   mem1.total_allocated_bytes = 7000;
+  mem1.in_use_allocated_bytes = 2000;
 
   mem2.max_rss_kb = 3;
-  mem2.total_allocated_bytes = 5000;
+  mem2.total_allocated_bytes = 7000;
+  mem2.in_use_allocated_bytes = 4000;
 
   const auto add_mem = mem1 + mem2;
   EXPECT_EQ(8, add_mem.max_rss_kb);
-  EXPECT_EQ(12000, add_mem.total_allocated_bytes);
+  EXPECT_EQ(14000, add_mem.total_allocated_bytes);
+  EXPECT_EQ(6000, add_mem.in_use_allocated_bytes);
 
   const auto sub_mem = mem1 - mem2;
   EXPECT_EQ(2, sub_mem.max_rss_kb);
-  EXPECT_EQ(2000, sub_mem.total_allocated_bytes);
+  EXPECT_EQ(0, sub_mem.total_allocated_bytes);
+  EXPECT_EQ(-2000, sub_mem.in_use_allocated_bytes);
 }
 
 TEST(MemoryUsage, GetMemoryUsage) {
   MemoryUsage result;
   EXPECT_EQ(MemoryUsage::kValueNotSet, result.max_rss_kb);
   EXPECT_EQ(MemoryUsage::kValueNotSet, result.total_allocated_bytes);
+  EXPECT_EQ(MemoryUsage::kValueNotSet, result.in_use_allocated_bytes);
 
 #ifdef __linux__
   // Just allocate some space in heap so that we could meaningful memory usage
diff --git a/tensorflow/lite/tools/benchmark/benchmark_model.cc b/tensorflow/lite/tools/benchmark/benchmark_model.cc
index f7ce6d86ab3..6c3fccc5e22 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_model.cc
+++ b/tensorflow/lite/tools/benchmark/benchmark_model.cc
@@ -192,8 +192,13 @@ TfLiteStatus BenchmarkModel::Run() {
                              inference_time_us, init_mem_usage,
                              overall_mem_usage});
 
-  TFLITE_LOG(INFO) << "Init " << init_mem_usage << std::endl
-                   << "Overall " << overall_mem_usage;
+  TFLITE_LOG(INFO)
+      << "Note: as the benchmark tool itself affects memory footprint, the "
+         "following is only APPROXIMATE to the actual memory footprint of the "
+         "model at runtime. Take the information at your discretion.";
+  TFLITE_LOG(INFO) << "Peak memory footprint (MB): init="
+                   << init_mem_usage.max_rss_kb / 1024.0
+                   << " overall=" << overall_mem_usage.max_rss_kb / 1024.0;
 
   return status;
 }

From 764a3ab93ac7425b49b9c13dc151bc9c2f2badf6 Mon Sep 17 00:00:00 2001
From: Xunkai Zhang 
Date: Fri, 29 Nov 2019 07:13:52 -0800
Subject: [PATCH 105/279] Improve the format in Support Library examples.

PiperOrigin-RevId: 283055924
Change-Id: I89b3062cb0bc07afc888d1b7774cf5d54aeda2e0
---
 tensorflow/lite/experimental/support/java/README.md | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/tensorflow/lite/experimental/support/java/README.md b/tensorflow/lite/experimental/support/java/README.md
index bc00123af70..d5f3e121f3a 100644
--- a/tensorflow/lite/experimental/support/java/README.md
+++ b/tensorflow/lite/experimental/support/java/README.md
@@ -114,8 +114,9 @@ try{
 }
 
 // Running inference
-if(null != tflite)
+if(null != tflite) {
     tflite.run(tImage.getBuffer(), probabilityBuffer.getBuffer());
+}
 ```
 
 ### Accessing the result
@@ -138,9 +139,9 @@ import org.tensorflow.lite.support.common.FileUtil;
 final String ASSOCIATED_AXIS_LABELS = "labels.txt";
 List associatedAxisLabels = null;
 
-try{
+try {
     associatedAxisLabels = FileUtil.loadLabels(this, ASSOCIATED_AXIS_LABELS);
-} catch (IOException e){
+} catch (IOException e) {
     Log.e("tfliteSupport", "Error reading label file", e);
 }
 ```
@@ -192,11 +193,11 @@ int size = height > width ? width : height;
 ImageProcessor imageProcessor =
     new ImageProcessor.Builder()
         // Center crop the image to the largest square possible
-        .add(new ResizeWithCropOrPadOp(size , size))
+        .add(new ResizeWithCropOrPadOp(size, size))
         // Resize using Bilinear or Nearest neighbour
         .add(new ResizeOp(224, 224, ResizeOp.ResizeMethod.BILINEAR));
         // Rotation counter-clockwise in 90 degree increments
-        .add(new Rot90Op(rotateDegrees/90))
+        .add(new Rot90Op(rotateDegrees / 90))
         .build();
 ```
 
@@ -229,5 +230,5 @@ TensorProcessor probabilityProcessor =
     new TensorProcessor.Builder().add(new NormalizeOp(0, 255)).build();
 
 // Post-processor which dequantize the result
-TensorBuffer dequantizedBuffer = probabilityProcessor.process(probabilityBuffer)
+TensorBuffer dequantizedBuffer = probabilityProcessor.process(probabilityBuffer);
 ```

From 67328cf44fb8f8442c94b7ec80b48854a577a795 Mon Sep 17 00:00:00 2001
From: JKIsaacLee <51275047+JKIsaacLee@users.noreply.github.com>
Date: Fri, 29 Nov 2019 08:48:22 -0800
Subject: [PATCH 106/279] Fixed typo in Ch-1 of Toy tutorial

Closes #282

PiperOrigin-RevId: 283064785
Change-Id: Ifac8f03310a4548f6f45399e5b2a2693bd7b0fe6
---
 third_party/mlir/g3doc/Tutorials/Toy/Ch-1.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/third_party/mlir/g3doc/Tutorials/Toy/Ch-1.md b/third_party/mlir/g3doc/Tutorials/Toy/Ch-1.md
index b8beff8d3f5..cb7f97cb3f6 100644
--- a/third_party/mlir/g3doc/Tutorials/Toy/Ch-1.md
+++ b/third_party/mlir/g3doc/Tutorials/Toy/Ch-1.md
@@ -62,7 +62,7 @@ def main() {
   var b<2, 3> = [1, 2, 3, 4, 5, 6];
 
   # transpose() and print() are the only builtin, the following will transpose
-  # b and perform an element-wise multiplication before printing the result.
+  # a and b and perform an element-wise multiplication before printing the result.
   print(transpose(a) * transpose(b));
 }
 ```

From 997c0893f24963a959da129594be9aca0fd8c7e4 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar 
Date: Fri, 29 Nov 2019 10:26:23 -0800
Subject: [PATCH 107/279] Fix redundant convert and use NamedAttributeList as
 value

* Had leftover call that would result in converting to dictionary attr before
  being implicitedly converted back to NamedAttributeList;
* NamedAttributeList is value typed, so don't use const reference;

PiperOrigin-RevId: 283072576
Change-Id: If69809291a918378fefa99fde14e491150eea726
---
 third_party/mlir/include/mlir/IR/Operation.h |  4 ++--
 third_party/mlir/lib/IR/Operation.cpp        | 12 ++++++------
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/third_party/mlir/include/mlir/IR/Operation.h b/third_party/mlir/include/mlir/IR/Operation.h
index 8a7bad13d69..27bc1b17b63 100644
--- a/third_party/mlir/include/mlir/IR/Operation.h
+++ b/third_party/mlir/include/mlir/IR/Operation.h
@@ -63,7 +63,7 @@ public:
   static Operation *create(Location location, OperationName name,
                            ArrayRef resultTypes,
                            ArrayRef operands,
-                           const NamedAttributeList &attributes,
+                           NamedAttributeList attributes,
                            ArrayRef successors, unsigned numRegions,
                            bool resizableOperandList);
 
@@ -74,7 +74,7 @@ public:
   static Operation *create(Location location, OperationName name,
                            ArrayRef resultTypes,
                            ArrayRef operands,
-                           const NamedAttributeList &attributes,
+                           NamedAttributeList attributes,
                            ArrayRef successors = {},
                            ArrayRef> regions = {},
                            bool resizableOperandList = false);
diff --git a/third_party/mlir/lib/IR/Operation.cpp b/third_party/mlir/lib/IR/Operation.cpp
index e5ec43c699b..f0ebd59ab9f 100644
--- a/third_party/mlir/lib/IR/Operation.cpp
+++ b/third_party/mlir/lib/IR/Operation.cpp
@@ -125,17 +125,17 @@ Operation *Operation::create(Location location, OperationName name,
 
 /// Create a new Operation from operation state.
 Operation *Operation::create(const OperationState &state) {
-  return Operation::create(
-      state.location, state.name, state.types, state.operands,
-      NamedAttributeList(state.attributes).getDictionary(), state.successors,
-      state.regions, state.resizableOperandList);
+  return Operation::create(state.location, state.name, state.types,
+                           state.operands, NamedAttributeList(state.attributes),
+                           state.successors, state.regions,
+                           state.resizableOperandList);
 }
 
 /// Create a new Operation with the specific fields.
 Operation *Operation::create(Location location, OperationName name,
                              ArrayRef resultTypes,
                              ArrayRef operands,
-                             const NamedAttributeList &attributes,
+                             NamedAttributeList attributes,
                              ArrayRef successors,
                              ArrayRef> regions,
                              bool resizableOperandList) {
@@ -153,7 +153,7 @@ Operation *Operation::create(Location location, OperationName name,
 Operation *Operation::create(Location location, OperationName name,
                              ArrayRef resultTypes,
                              ArrayRef operands,
-                             const NamedAttributeList &attributes,
+                             NamedAttributeList attributes,
                              ArrayRef successors, unsigned numRegions,
                              bool resizableOperandList) {
   unsigned numSuccessors = successors.size();

From 427c708efae84e491a92d27c3edcaa1b15109d83 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar 
Date: Fri, 29 Nov 2019 10:43:24 -0800
Subject: [PATCH 108/279] mlir-tblgen: Dump input records when no generator is
 set

Follow LLVM's tblgen convention when no generator is set instead of asserting.

PiperOrigin-RevId: 283073690
Change-Id: Iaa7e6f9fd282f95a2300598dde9dd83a9e497f0a
---
 third_party/mlir/tools/mlir-tblgen/mlir-tblgen.cpp | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/third_party/mlir/tools/mlir-tblgen/mlir-tblgen.cpp b/third_party/mlir/tools/mlir-tblgen/mlir-tblgen.cpp
index 50b680d904d..993a05d7095 100644
--- a/third_party/mlir/tools/mlir-tblgen/mlir-tblgen.cpp
+++ b/third_party/mlir/tools/mlir-tblgen/mlir-tblgen.cpp
@@ -74,7 +74,10 @@ const mlir::GenInfo *generator;
 // TableGenMain requires a function pointer so this function is passed in which
 // simply wraps the call to the generator.
 static bool MlirTableGenMain(raw_ostream &os, RecordKeeper &records) {
-  assert(generator && "no generator specified");
+  if (!generator) {
+    os << records;
+    return false;
+  }
   return generator->invoke(records, os);
 }
 

From 14a0c12dc2eb6398414df4b92d47455ac37ff4fd Mon Sep 17 00:00:00 2001
From: Gaurav Jain 
Date: Fri, 29 Nov 2019 15:54:44 -0800
Subject: [PATCH 109/279] Remove some stale forward compatibility dates

PiperOrigin-RevId: 283092973
Change-Id: Ia708a9c04a032e1222c7d56ad4936e263424fbdd
---
 tensorflow/python/eager/function.py | 26 +++-----------------------
 tensorflow/python/ops/math_grad.py  | 22 ++++++++--------------
 tensorflow/python/ops/math_ops.py   | 11 ++---------
 3 files changed, 13 insertions(+), 46 deletions(-)

diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 07619c882e5..63263c03a97 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -34,7 +34,6 @@ from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import function_pb2
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python import _pywrap_utils
-from tensorflow.python.compat import compat as fwd_compat
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import backprop_util
 from tensorflow.python.eager import context
@@ -1030,18 +1029,8 @@ class _TapeGradientFunctions(object):
         with ops.get_default_graph()._override_gradient_function(  # pylint: disable=protected-access
             {"PartitionedCall": gradient_function,
              "StatefulPartitionedCall": gradient_function}):
-          # Previously, we relyed on "_gradient_op_type" attribute to restore a
-          # function gradient in function_deserialization.py, So add a dummy
-          # value "PartitionedCallUnused" for the forward compatibility.
-          if fwd_compat.forward_compatible(2019, 11, 16):
-            forward_outputs = forward_function.call(context.context(),
-                                                    forward_inputs)
-          else:
-            with ops.get_default_graph().gradient_override_map(
-                {"PartitionedCall": "PartitionedCallUnused",
-                 "StatefulPartitionedCall": "PartitionedCallUnused"}):
-              forward_outputs = forward_function.call(context.context(),
-                                                      forward_inputs)
+          forward_outputs = forward_function.call(context.context(),
+                                                  forward_inputs)
         py_backward, _ = self._wrap_backward_function(
             self._func_graph, backward_function, forward_outputs)
       # We will never request backward tape gradients for this operation
@@ -1703,16 +1692,7 @@ class ConcreteFunction(object):
       with ops.get_default_graph()._override_gradient_function(  # pylint: disable=protected-access
           {"PartitionedCall": self._get_gradient_function(),
            "StatefulPartitionedCall": self._get_gradient_function()}):
-        # Previously, we relyed on "_gradient_op_type" attribute to restore a
-        # function gradient in function_deserialization.py. So add a dummy
-        # value "PartitionedCallUnused" for the forward compatibility.
-        if fwd_compat.forward_compatible(2019, 11, 16):
-          flat_outputs = forward_function.call(ctx, args_with_tangents)
-        else:
-          with ops.get_default_graph().gradient_override_map(
-              {"PartitionedCall": "PartitionedCallUnused",
-               "StatefulPartitionedCall": "PartitionedCallUnused"}):
-            flat_outputs = forward_function.call(ctx, args_with_tangents)
+        flat_outputs = forward_function.call(ctx, args_with_tangents)
     forward_backward.record(flat_outputs)
     return self._build_call_outputs(flat_outputs)
 
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index f7e01f57200..f9a75f6aecc 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -204,20 +204,14 @@ def _SumGrad(op, grad):
 
   input_shape = array_ops.shape(op.inputs[0])
 
-  if compat.forward_compatible(2019, 10, 23):
-    if not op.get_attr("keep_dims"):
-      with ops.colocate_with(input_shape):
-        # TODO(apassos) remove this once device placement for eager ops makes
-        # more sense.
-        output_shape_kept_dims = math_ops.reduced_shape(input_shape,
-                                                        op.inputs[1])
-      grad = array_ops.reshape(grad, output_shape_kept_dims)
-    return [array_ops.broadcast_to(grad, input_shape), None]
-  with ops.colocate_with(input_shape):
-    output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
-    tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
-  grad = array_ops.reshape(grad, output_shape_kept_dims)
-  return [array_ops.tile(grad, tile_scaling), None]
+  if not op.get_attr("keep_dims"):
+    with ops.colocate_with(input_shape):
+      # TODO(apassos) remove this once device placement for eager ops makes
+      # more sense.
+      output_shape_kept_dims = math_ops.reduced_shape(input_shape,
+                                                      op.inputs[1])
+    grad = array_ops.reshape(grad, output_shape_kept_dims)
+  return [array_ops.broadcast_to(grad, input_shape), None]
 
 
 def _MinOrMaxGrad(op, grad):
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 6e6fd50a419..078219e2f23 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -75,7 +75,6 @@ import six
 from six.moves import builtins
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
-from tensorflow.python.compat import compat as fwd_compat
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -1364,10 +1363,7 @@ def tensor_equals(self, other):
   g = getattr(self, "graph", None)
   if (ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions() and
       (g is None or g._building_function)):  # pylint: disable=protected-access
-    if fwd_compat.forward_compatible(2019, 9, 25):
-      return gen_math_ops.equal(self, other, incompatible_shape_error=False)
-    else:
-      return gen_math_ops.equal(self, other)
+    return gen_math_ops.equal(self, other, incompatible_shape_error=False)
   else:
     # In legacy graph mode, tensor equality is object equality
     return self is other
@@ -1378,10 +1374,7 @@ def tensor_not_equals(self, other):
   if other is None:
     return True
   if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions():
-    if fwd_compat.forward_compatible(2019, 9, 25):
-      return gen_math_ops.not_equal(self, other, incompatible_shape_error=False)
-    else:
-      return gen_math_ops.not_equal(self, other)
+    return gen_math_ops.not_equal(self, other, incompatible_shape_error=False)
   else:
     # In legacy graph mode, tensor equality is object equality
     return self is not other

From 7879c387f8a4a236cb031aa3d54252e9f14771be Mon Sep 17 00:00:00 2001
From: Gunhan Gulsoy 
Date: Fri, 29 Nov 2019 21:16:44 -0800
Subject: [PATCH 110/279] Move core/lib/random/random to core/platform

PiperOrigin-RevId: 283111631
Change-Id: I8ef4e09e5035721c45a64ecc7eb598ba448b82ed
---
 tensorflow/core/lib/random/BUILD              |  5 +--
 tensorflow/core/lib/random/random.h           | 16 +--------
 tensorflow/core/platform/BUILD                | 12 +++++++
 .../core/{lib/random => platform}/random.cc   |  2 +-
 tensorflow/core/platform/random.h             | 35 +++++++++++++++++++
 5 files changed, 50 insertions(+), 20 deletions(-)
 rename tensorflow/core/{lib/random => platform}/random.cc (96%)
 create mode 100644 tensorflow/core/platform/random.h

diff --git a/tensorflow/core/lib/random/BUILD b/tensorflow/core/lib/random/BUILD
index c9d48689849..7360e72f233 100644
--- a/tensorflow/core/lib/random/BUILD
+++ b/tensorflow/core/lib/random/BUILD
@@ -64,11 +64,9 @@ cc_library(
 
 cc_library(
     name = "random",
-    srcs = ["random.cc"],
     hdrs = ["random.h"],
     deps = [
-        "//tensorflow/core/platform:mutex",
-        "//tensorflow/core/platform:types",
+        "//tensorflow/core/platform:random",
     ],
 )
 
@@ -133,7 +131,6 @@ filegroup(
     name = "legacy_lib_random_all_srcs",
     srcs = [
         "distribution_sampler.cc",
-        "random.cc",
         "random_distributions.cc",
         "simple_philox.cc",
         "weighted_picker.cc",
diff --git a/tensorflow/core/lib/random/random.h b/tensorflow/core/lib/random/random.h
index 5335c8cc3c9..e280d98d551 100644
--- a/tensorflow/core/lib/random/random.h
+++ b/tensorflow/core/lib/random/random.h
@@ -16,20 +16,6 @@ limitations under the License.
 #ifndef TENSORFLOW_LIB_RANDOM_RANDOM_H_
 #define TENSORFLOW_LIB_RANDOM_RANDOM_H_
 
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-namespace random {
-
-// Return a 64-bit random value.  Different sequences are generated
-// in different processes.
-uint64 New64();
-
-// Return a 64-bit random value. Uses
-// std::mersenne_twister_engine::default_seed as seed value.
-uint64 New64DefaultSeed();
-
-}  // namespace random
-}  // namespace tensorflow
+#include "tensorflow/core/platform/random.h"
 
 #endif  // TENSORFLOW_LIB_RANDOM_RANDOM_H_
diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD
index 87b6fa1af1b..f743e01ba8a 100644
--- a/tensorflow/core/platform/BUILD
+++ b/tensorflow/core/platform/BUILD
@@ -376,6 +376,16 @@ cc_library(
     ] + if_static(["@com_google_protobuf//:protobuf"]),
 )
 
+cc_library(
+    name = "random",
+    srcs = ["random.cc"],
+    hdrs = ["random.h"],
+    deps = [
+        ":mutex",
+        ":types",
+    ],
+)
+
 cc_library(
     name = "raw_coding",
     hdrs = ["raw_coding.h"],
@@ -713,6 +723,7 @@ filegroup(
             "numbers.cc",
             "platform_strings.cc",
             "protobuf.cc",
+            "random.cc",
             "scanner.cc",
             "strcat.cc",
             "stringprintf.cc",
@@ -824,6 +835,7 @@ filegroup(
             "platform_strings.cc",
             "protobuf.cc",
             "protobuf_util.cc",
+            "random.cc",
             "scanner.cc",
             "setround.cc",
             "status.cc",
diff --git a/tensorflow/core/lib/random/random.cc b/tensorflow/core/platform/random.cc
similarity index 96%
rename from tensorflow/core/lib/random/random.cc
rename to tensorflow/core/platform/random.cc
index 82dc8295073..d7252810021 100644
--- a/tensorflow/core/lib/random/random.cc
+++ b/tensorflow/core/platform/random.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/random.h"
 
 #include 
 #include "tensorflow/core/platform/mutex.h"
diff --git a/tensorflow/core/platform/random.h b/tensorflow/core/platform/random.h
new file mode 100644
index 00000000000..f605fd9e477
--- /dev/null
+++ b/tensorflow/core/platform/random.h
@@ -0,0 +1,35 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_RANDOM_H_
+#define TENSORFLOW_CORE_PLATFORM_RANDOM_H_
+
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace random {
+
+// Return a 64-bit random value.  Different sequences are generated
+// in different processes.
+uint64 New64();
+
+// Return a 64-bit random value. Uses
+// std::mersenne_twister_engine::default_seed as seed value.
+uint64 New64DefaultSeed();
+
+}  // namespace random
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PLATFORM_RANDOM_H_

From 4f2e7e08db34fdadd4d1460888f08e163fc00a4c Mon Sep 17 00:00:00 2001
From: Gunhan Gulsoy 
Date: Fri, 29 Nov 2019 21:33:13 -0800
Subject: [PATCH 111/279] Remove dependence on core/lib/strings/scanner within
 core/platform

PiperOrigin-RevId: 283112406
Change-Id: I60fc38818ae34fd775929999788f48c3388dd626
---
 tensorflow/core/platform/cloud/BUILD                | 2 ++
 tensorflow/core/platform/cloud/curl_http_request.cc | 2 +-
 tensorflow/core/platform/cloud/oauth_client_test.cc | 4 +++-
 3 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD
index 8c87b5c4bcf..85aa864b8a7 100644
--- a/tensorflow/core/platform/cloud/BUILD
+++ b/tensorflow/core/platform/cloud/BUILD
@@ -160,6 +160,7 @@ cc_library(
         ":http_request",
         "//tensorflow/core:framework_headers_lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/platform:scanner",
         "//tensorflow/core/platform:str_util",
         "//tensorflow/core/platform:stringpiece",
         "@curl",
@@ -416,6 +417,7 @@ tf_cc_test(
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "//tensorflow/core/platform:scanner",
         "@boringssl//:crypto",
     ],
 )
diff --git a/tensorflow/core/platform/cloud/curl_http_request.cc b/tensorflow/core/platform/cloud/curl_http_request.cc
index f53fce63750..b3646eba391 100644
--- a/tensorflow/core/platform/cloud/curl_http_request.cc
+++ b/tensorflow/core/platform/cloud/curl_http_request.cc
@@ -19,10 +19,10 @@ limitations under the License.
 
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/lib/strings/scanner.h"
 #include "tensorflow/core/platform/cloud/curl_http_request.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/scanner.h"
 #include "tensorflow/core/platform/str_util.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/public/version.h"
diff --git a/tensorflow/core/platform/cloud/oauth_client_test.cc b/tensorflow/core/platform/cloud/oauth_client_test.cc
index 7b76e4c6c16..ca24365434b 100644
--- a/tensorflow/core/platform/cloud/oauth_client_test.cc
+++ b/tensorflow/core/platform/cloud/oauth_client_test.cc
@@ -14,16 +14,18 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/platform/cloud/oauth_client.h"
+
 #include 
+
 #include 
 #include 
 #include 
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/lib/strings/base64.h"
-#include "tensorflow/core/lib/strings/scanner.h"
 #include "tensorflow/core/platform/cloud/http_request_fake.h"
 #include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/scanner.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {

From 74f306e3cdf653338ed40a08c38b50aed8ed810b Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Fri, 29 Nov 2019 21:46:13 -0800
Subject: [PATCH 112/279] Add more details in the tfl.pack error string.

PiperOrigin-RevId: 283113195
Change-Id: I7a69145252792b742b9d1d66152aa3d6eff713e8
---
 tensorflow/compiler/mlir/lite/ir/tfl_ops.cc | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
index 0549eadc88a..1df8026643b 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -720,7 +720,8 @@ static LogicalResult Verify(PackOp op) {
   for (Value *operand : op.getOperands()) {
     auto other_type = operand->getType().cast();
     if (input_type != other_type)
-      return op.emitOpError("operands should be of the same type");
+      return op.emitOpError("operands should be of the same type. got ")
+             << input_type << ", " << other_type;
   }
 
   return success();

From aaea5414317793a3eedc41f23cc5e619ee5941dc Mon Sep 17 00:00:00 2001
From: Gunhan Gulsoy 
Date: Fri, 29 Nov 2019 22:08:00 -0800
Subject: [PATCH 113/279] Bump minimum bazel version requirement to 1.0.0

PiperOrigin-RevId: 283114685
Change-Id: Ie7160112ff379fcc7e4c4794db20f4eb24f5f8df
---
 WORKSPACE    | 2 +-
 configure.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/WORKSPACE b/WORKSPACE
index babb14b509e..48536a5d1d0 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -89,7 +89,7 @@ swift_rules_dependencies()
 # files, in case the parsing of those build files depends on the bazel
 # version we require here.
 load("//tensorflow:version_check.bzl", "check_bazel_version_at_least")
-check_bazel_version_at_least("0.19.0")
+check_bazel_version_at_least("1.0.0")
 
 load("//third_party/android:android_configure.bzl", "android_configure")
 android_configure(name="local_config_android")
diff --git a/configure.py b/configure.py
index e02428a25a2..2c7914052e9 100644
--- a/configure.py
+++ b/configure.py
@@ -49,7 +49,7 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
 _TF_WORKSPACE_ROOT = ''
 _TF_BAZELRC = ''
 _TF_CURRENT_BAZEL_VERSION = None
-_TF_MIN_BAZEL_VERSION = '0.27.1'
+_TF_MIN_BAZEL_VERSION = '1.0.0'
 _TF_MAX_BAZEL_VERSION = '1.1.0'
 
 NCCL_LIB_PATHS = [

From 76c94f50f394a6fa210046010395472e193a4f14 Mon Sep 17 00:00:00 2001
From: Gunhan Gulsoy 
Date: Fri, 29 Nov 2019 22:11:04 -0800
Subject: [PATCH 114/279] Remove dependence on core/lib/strings/numbers in
 core/platform

PiperOrigin-RevId: 283114863
Change-Id: Ieb2c45737840b1e6ab4563bb95ec81e2636310bb
---
 tensorflow/core/platform/cloud/BUILD              | 2 ++
 tensorflow/core/platform/cloud/gcs_file_system.cc | 2 +-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD
index 85aa864b8a7..a4019273fc9 100644
--- a/tensorflow/core/platform/cloud/BUILD
+++ b/tensorflow/core/platform/cloud/BUILD
@@ -99,6 +99,7 @@ cc_library(
         "//tensorflow/core:framework_headers_lib",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/platform:numbers",
         "//tensorflow/core/platform:str_util",
         "//tensorflow/core/platform:stringprintf",
         "@jsoncpp_git//:jsoncpp",
@@ -133,6 +134,7 @@ cc_library(
         "//tensorflow/core:framework_headers_lib",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/platform:numbers",
         "//tensorflow/core/platform:str_util",
         "//tensorflow/core/platform:stringprintf",
         "@jsoncpp_git//:jsoncpp",
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index 8c4cb831346..c55d4ef257e 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -29,7 +29,6 @@ limitations under the License.
 #include "include/json/json.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/numbers.h"
 #include "tensorflow/core/platform/cloud/curl_http_request.h"
 #include "tensorflow/core/platform/cloud/file_block_cache.h"
 #include "tensorflow/core/platform/cloud/google_auth_provider.h"
@@ -39,6 +38,7 @@ limitations under the License.
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/numbers.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/str_util.h"
 #include "tensorflow/core/platform/stringprintf.h"

From 76f7369407e908a9904254642887d4fd09d4d88f Mon Sep 17 00:00:00 2001
From: Gunhan Gulsoy 
Date: Fri, 29 Nov 2019 22:12:07 -0800
Subject: [PATCH 115/279] Remove dependence on core/lib/strings/strcat in
 core/platform

PiperOrigin-RevId: 283114916
Change-Id: I2bedbef22629bf2f65ec6284839a8d5c4c3c99ec
---
 tensorflow/core/platform/env_test.cc                  | 2 +-
 tensorflow/core/platform/file_system_test.cc          | 2 +-
 tensorflow/core/platform/hadoop/BUILD                 | 1 +
 tensorflow/core/platform/hadoop/hadoop_file_system.cc | 2 +-
 4 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc
index 06e09b1911d..bee02dbeeed 100644
--- a/tensorflow/core/platform/env_test.cc
+++ b/tensorflow/core/platform/env_test.cc
@@ -21,11 +21,11 @@ limitations under the License.
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/cord.h"
 #include "tensorflow/core/platform/null_file_system.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/str_util.h"
+#include "tensorflow/core/platform/strcat.h"
 #include "tensorflow/core/platform/stringpiece.h"
 #include "tensorflow/core/platform/test.h"
 
diff --git a/tensorflow/core/platform/file_system_test.cc b/tensorflow/core/platform/file_system_test.cc
index 78358bd3458..8b577c37c75 100644
--- a/tensorflow/core/platform/file_system_test.cc
+++ b/tensorflow/core/platform/file_system_test.cc
@@ -19,9 +19,9 @@ limitations under the License.
 
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/null_file_system.h"
 #include "tensorflow/core/platform/str_util.h"
+#include "tensorflow/core/platform/strcat.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/platform/hadoop/BUILD b/tensorflow/core/platform/hadoop/BUILD
index b68eb0bcd4f..dc42b901b62 100644
--- a/tensorflow/core/platform/hadoop/BUILD
+++ b/tensorflow/core/platform/hadoop/BUILD
@@ -18,6 +18,7 @@ cc_library(
     deps = [
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/platform:strcat",
         "//third_party/hadoop:hdfs",
     ],
     alwayslink = 1,
diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.cc b/tensorflow/core/platform/hadoop/hadoop_file_system.cc
index 60668ff4f61..6a5c4115189 100644
--- a/tensorflow/core/platform/hadoop/hadoop_file_system.cc
+++ b/tensorflow/core/platform/hadoop/hadoop_file_system.cc
@@ -18,7 +18,6 @@ limitations under the License.
 #include 
 
 #include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/error.h"
 #include "tensorflow/core/platform/file_system.h"
@@ -26,6 +25,7 @@ limitations under the License.
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/strcat.h"
 #include "third_party/hadoop/hdfs.h"
 
 namespace tensorflow {

From 75924f989300472ad59db694840d45390fedc118 Mon Sep 17 00:00:00 2001
From: Gunhan Gulsoy 
Date: Fri, 29 Nov 2019 22:49:58 -0800
Subject: [PATCH 116/279] Move tensorflow/core/lib/io/path to
 tensorflow/core/platform

PiperOrigin-RevId: 283116521
Change-Id: I739b67baa27339634c16ea4656a50bfa827793f4
---
 tensorflow/core/BUILD                         |  1 +
 tensorflow/core/lib/io/BUILD                  | 11 +--
 tensorflow/core/lib/io/path.h                 | 79 +--------------
 tensorflow/core/platform/BUILD                | 17 ++++
 .../core/platform/default/build_refactor.bzl  |  4 +-
 .../core/platform/default/posix_file_system.h |  2 +-
 tensorflow/core/platform/env.cc               |  2 +-
 tensorflow/core/platform/file_system.cc       |  2 +-
 .../core/platform/file_system_helper.cc       |  2 +-
 tensorflow/core/{lib/io => platform}/path.cc  |  2 +-
 tensorflow/core/platform/path.h               | 98 +++++++++++++++++++
 11 files changed, 125 insertions(+), 95 deletions(-)
 rename tensorflow/core/{lib/io => platform}/path.cc (99%)
 create mode 100644 tensorflow/core/platform/path.h

diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 29fe1c92932..3839442c167 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -2277,6 +2277,7 @@ cc_library(
         "//tensorflow/core/platform:net",
         "//tensorflow/core/platform:null_file_system",
         "//tensorflow/core/platform:numbers",
+        "//tensorflow/core/platform:path",
         "//tensorflow/core/platform:platform_port",
         "//tensorflow/core/platform:platform_strings",
         "//tensorflow/core/platform:prefetch",
diff --git a/tensorflow/core/lib/io/BUILD b/tensorflow/core/lib/io/BUILD
index aa4f34d45c5..123e24db3c7 100644
--- a/tensorflow/core/lib/io/BUILD
+++ b/tensorflow/core/lib/io/BUILD
@@ -4,8 +4,6 @@ package(
         "//tensorflow/c/experimental/filesystem/plugins/posix:__pkg__",
         # tensorflow/core:lib effectively exposes all targets under tensorflow/core/lib/**
         "//tensorflow/core:__pkg__",
-        # tensorflow/core/platform:env uses :path
-        "//tensorflow/core/platform:__pkg__",
     ],
     licenses = ["notice"],  # Apache 2.0
 )
@@ -103,15 +101,9 @@ cc_library(
 
 cc_library(
     name = "path",
-    srcs = ["path.cc"],
     hdrs = ["path.h"],
     deps = [
-        "//tensorflow/core/platform:logging",
-        "//tensorflow/core/platform:mutex",
-        "//tensorflow/core/platform:scanner",
-        "//tensorflow/core/platform:strcat",
-        "//tensorflow/core/platform:stringpiece",
-        "//tensorflow/core/platform:types",
+        "//tensorflow/core/platform:path",
     ],
     alwayslink = True,
 )
@@ -322,7 +314,6 @@ filegroup(
         "inputbuffer.cc",
         "inputstream_interface.cc",
         "iterator.cc",
-        "path.cc",
         "random_inputstream.cc",
         "record_reader.cc",
         "record_writer.cc",
diff --git a/tensorflow/core/lib/io/path.h b/tensorflow/core/lib/io/path.h
index 7cfd9809fdb..f5deacd1026 100644
--- a/tensorflow/core/lib/io/path.h
+++ b/tensorflow/core/lib/io/path.h
@@ -16,83 +16,6 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_LIB_IO_PATH_H_
 #define TENSORFLOW_CORE_LIB_IO_PATH_H_
 
-#include "tensorflow/core/platform/stringpiece.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-namespace io {
-namespace internal {
-string JoinPathImpl(std::initializer_list paths);
-}
-
-// Utility routines for processing filenames
-
-#ifndef SWIG  // variadic templates
-// Join multiple paths together, without introducing unnecessary path
-// separators.
-// For example:
-//
-//  Arguments                  | JoinPath
-//  ---------------------------+----------
-//  '/foo', 'bar'              | /foo/bar
-//  '/foo/', 'bar'             | /foo/bar
-//  '/foo', '/bar'             | /foo/bar
-//
-// Usage:
-// string path = io::JoinPath("/mydir", filename);
-// string path = io::JoinPath(FLAGS_test_srcdir, filename);
-// string path = io::JoinPath("/full", "path", "to", "filename");
-template 
-string JoinPath(const T&... args) {
-  return internal::JoinPathImpl({args...});
-}
-#endif /* SWIG */
-
-// Return true if path is absolute.
-bool IsAbsolutePath(tensorflow::StringPiece path);
-
-// Returns the part of the path before the final "/".  If there is a single
-// leading "/" in the path, the result will be the leading "/".  If there is
-// no "/" in the path, the result is the empty prefix of the input.
-tensorflow::StringPiece Dirname(tensorflow::StringPiece path);
-
-// Returns the part of the path after the final "/".  If there is no
-// "/" in the path, the result is the same as the input.
-tensorflow::StringPiece Basename(tensorflow::StringPiece path);
-
-// Returns the part of the basename of path after the final ".".  If
-// there is no "." in the basename, the result is empty.
-tensorflow::StringPiece Extension(tensorflow::StringPiece path);
-
-// Collapse duplicate "/"s, resolve ".." and "." path elements, remove
-// trailing "/".
-//
-// NOTE: This respects relative vs. absolute paths, but does not
-// invoke any system calls (getcwd(2)) in order to resolve relative
-// paths with respect to the actual working directory.  That is, this is purely
-// string manipulation, completely independent of process state.
-string CleanPath(tensorflow::StringPiece path);
-
-// Populates the scheme, host, and path from a URI. scheme, host, and path are
-// guaranteed by this function to point into the contents of uri, even if
-// empty.
-//
-// Corner cases:
-// - If the URI is invalid, scheme and host are set to empty strings and the
-//   passed string is assumed to be a path
-// - If the URI omits the path (e.g. file://host), then the path is left empty.
-void ParseURI(tensorflow::StringPiece uri, tensorflow::StringPiece* scheme,
-              tensorflow::StringPiece* host, tensorflow::StringPiece* path);
-
-// Creates a URI from a scheme, host, and path. If the scheme is empty, we just
-// return the path.
-string CreateURI(tensorflow::StringPiece scheme, tensorflow::StringPiece host,
-                 tensorflow::StringPiece path);
-
-// Creates a temporary file name with an extension.
-string GetTempFilename(const string& extension);
-
-}  // namespace io
-}  // namespace tensorflow
+#include "tensorflow/core/platform/path.h"
 
 #endif  // TENSORFLOW_CORE_LIB_IO_PATH_H_
diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD
index f743e01ba8a..ecc44e39c11 100644
--- a/tensorflow/core/platform/BUILD
+++ b/tensorflow/core/platform/BUILD
@@ -303,6 +303,21 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "path",
+    srcs = ["path.cc"],
+    hdrs = ["path.h"],
+    deps = [
+        ":logging",
+        ":mutex",
+        ":scanner",
+        ":strcat",
+        ":stringpiece",
+        ":types",
+    ],
+    alwayslink = True,
+)
+
 cc_library(
     name = "platform",
     hdrs = ["platform.h"],
@@ -721,6 +736,7 @@ filegroup(
             "abi.cc",
             "cpu_info.cc",
             "numbers.cc",
+            "path.cc",
             "platform_strings.cc",
             "protobuf.cc",
             "random.cc",
@@ -832,6 +848,7 @@ filegroup(
             "file_system_helper.cc",
             "logger.cc",
             "numbers.cc",
+            "path.cc",
             "platform_strings.cc",
             "protobuf.cc",
             "protobuf_util.cc",
diff --git a/tensorflow/core/platform/default/build_refactor.bzl b/tensorflow/core/platform/default/build_refactor.bzl
index 12246af4460..5dbada0a08e 100644
--- a/tensorflow/core/platform/default/build_refactor.bzl
+++ b/tensorflow/core/platform/default/build_refactor.bzl
@@ -80,7 +80,6 @@ TF_DEFAULT_PLATFORM_LIBRARIES = {
             "//tensorflow/core/lib/core:blocking_counter",
             "//tensorflow/core/lib/core:error_codes_proto_cc",
             "//tensorflow/core/lib/core:stringpiece",
-            "//tensorflow/core/lib/io:path",
             "//tensorflow/core/platform",
             "//tensorflow/core/platform:context",
             "//tensorflow/core/platform:cord",
@@ -93,6 +92,7 @@ TF_DEFAULT_PLATFORM_LIBRARIES = {
             "//tensorflow/core/platform:logging",
             "//tensorflow/core/platform:macros",
             "//tensorflow/core/platform:mutex",
+            "//tensorflow/core/platform:path",
             "//tensorflow/core/platform:platform_port",
             "//tensorflow/core/platform:protobuf",
             "//tensorflow/core/platform:setround",
@@ -405,7 +405,6 @@ TF_WINDOWS_PLATFORM_LIBRARIES = {
             "//tensorflow/core/lib/core:blocking_counter",
             "//tensorflow/core/lib/core:error_codes_proto_cc",
             "//tensorflow/core/lib/core:stringpiece",
-            "//tensorflow/core/lib/io:path",
             "//tensorflow/core/platform",
             "//tensorflow/core/platform:context",
             "//tensorflow/core/platform:cord",
@@ -418,6 +417,7 @@ TF_WINDOWS_PLATFORM_LIBRARIES = {
             "//tensorflow/core/platform:logging",
             "//tensorflow/core/platform:macros",
             "//tensorflow/core/platform:mutex",
+            "//tensorflow/core/platform:path",
             "//tensorflow/core/platform:platform_port",
             "//tensorflow/core/platform:protobuf",
             "//tensorflow/core/platform:setround",
diff --git a/tensorflow/core/platform/default/posix_file_system.h b/tensorflow/core/platform/default/posix_file_system.h
index 752eccea66b..c418a08e944 100644
--- a/tensorflow/core/platform/default/posix_file_system.h
+++ b/tensorflow/core/platform/default/posix_file_system.h
@@ -16,8 +16,8 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_PLATFORM_POSIX_POSIX_FILE_SYSTEM_H_
 #define TENSORFLOW_CORE_PLATFORM_POSIX_POSIX_FILE_SYSTEM_H_
 
-#include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/path.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/platform/env.cc b/tensorflow/core/platform/env.cc
index 301b4c0e81e..ee4ae92f905 100644
--- a/tensorflow/core/platform/env.cc
+++ b/tensorflow/core/platform/env.cc
@@ -21,10 +21,10 @@ limitations under the License.
 #include 
 #include 
 
-#include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/env_time.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/host_info.h"
+#include "tensorflow/core/platform/path.h"
 #include "tensorflow/core/platform/platform.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/stringprintf.h"
diff --git a/tensorflow/core/platform/file_system.cc b/tensorflow/core/platform/file_system.cc
index 3a1d40e50e2..fb013b82570 100644
--- a/tensorflow/core/platform/file_system.cc
+++ b/tensorflow/core/platform/file_system.cc
@@ -20,9 +20,9 @@ limitations under the License.
 #include 
 #include 
 
-#include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/path.h"
 #include "tensorflow/core/platform/platform.h"
 #include "tensorflow/core/platform/str_util.h"
 #include "tensorflow/core/platform/strcat.h"
diff --git a/tensorflow/core/platform/file_system_helper.cc b/tensorflow/core/platform/file_system_helper.cc
index c909b36688e..da3acba7d1a 100644
--- a/tensorflow/core/platform/file_system_helper.cc
+++ b/tensorflow/core/platform/file_system_helper.cc
@@ -19,9 +19,9 @@ limitations under the License.
 #include 
 #include 
 
-#include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/file_system.h"
+#include "tensorflow/core/platform/path.h"
 #include "tensorflow/core/platform/platform.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/str_util.h"
diff --git a/tensorflow/core/lib/io/path.cc b/tensorflow/core/platform/path.cc
similarity index 99%
rename from tensorflow/core/lib/io/path.cc
rename to tensorflow/core/platform/path.cc
index ea9d93629c9..864bf49b2bb 100644
--- a/tensorflow/core/lib/io/path.cc
+++ b/tensorflow/core/platform/path.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/path.h"
 
 #include 
 #include 
diff --git a/tensorflow/core/platform/path.h b/tensorflow/core/platform/path.h
new file mode 100644
index 00000000000..db0348d8960
--- /dev/null
+++ b/tensorflow/core/platform/path.h
@@ -0,0 +1,98 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_PATH_H_
+#define TENSORFLOW_CORE_PLATFORM_PATH_H_
+
+#include "tensorflow/core/platform/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace io {
+namespace internal {
+string JoinPathImpl(std::initializer_list paths);
+}
+
+// Utility routines for processing filenames
+
+#ifndef SWIG  // variadic templates
+// Join multiple paths together, without introducing unnecessary path
+// separators.
+// For example:
+//
+//  Arguments                  | JoinPath
+//  ---------------------------+----------
+//  '/foo', 'bar'              | /foo/bar
+//  '/foo/', 'bar'             | /foo/bar
+//  '/foo', '/bar'             | /foo/bar
+//
+// Usage:
+// string path = io::JoinPath("/mydir", filename);
+// string path = io::JoinPath(FLAGS_test_srcdir, filename);
+// string path = io::JoinPath("/full", "path", "to", "filename");
+template 
+string JoinPath(const T&... args) {
+  return internal::JoinPathImpl({args...});
+}
+#endif /* SWIG */
+
+// Return true if path is absolute.
+bool IsAbsolutePath(tensorflow::StringPiece path);
+
+// Returns the part of the path before the final "/".  If there is a single
+// leading "/" in the path, the result will be the leading "/".  If there is
+// no "/" in the path, the result is the empty prefix of the input.
+tensorflow::StringPiece Dirname(tensorflow::StringPiece path);
+
+// Returns the part of the path after the final "/".  If there is no
+// "/" in the path, the result is the same as the input.
+tensorflow::StringPiece Basename(tensorflow::StringPiece path);
+
+// Returns the part of the basename of path after the final ".".  If
+// there is no "." in the basename, the result is empty.
+tensorflow::StringPiece Extension(tensorflow::StringPiece path);
+
+// Collapse duplicate "/"s, resolve ".." and "." path elements, remove
+// trailing "/".
+//
+// NOTE: This respects relative vs. absolute paths, but does not
+// invoke any system calls (getcwd(2)) in order to resolve relative
+// paths with respect to the actual working directory.  That is, this is purely
+// string manipulation, completely independent of process state.
+string CleanPath(tensorflow::StringPiece path);
+
+// Populates the scheme, host, and path from a URI. scheme, host, and path are
+// guaranteed by this function to point into the contents of uri, even if
+// empty.
+//
+// Corner cases:
+// - If the URI is invalid, scheme and host are set to empty strings and the
+//   passed string is assumed to be a path
+// - If the URI omits the path (e.g. file://host), then the path is left empty.
+void ParseURI(tensorflow::StringPiece uri, tensorflow::StringPiece* scheme,
+              tensorflow::StringPiece* host, tensorflow::StringPiece* path);
+
+// Creates a URI from a scheme, host, and path. If the scheme is empty, we just
+// return the path.
+string CreateURI(tensorflow::StringPiece scheme, tensorflow::StringPiece host,
+                 tensorflow::StringPiece path);
+
+// Creates a temporary file name with an extension.
+string GetTempFilename(const string& extension);
+
+}  // namespace io
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PLATFORM_PATH_H_

From dbceeb467420cd8e69d238830947971594b2cbee Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Sat, 30 Nov 2019 01:03:03 -0800
Subject: [PATCH 117/279] compat: Update forward compatibility horizon to
 2019-11-30

PiperOrigin-RevId: 283124384
Change-Id: I50f4a2421a15b7e10b2c02d1234a98bdb870c412
---
 tensorflow/python/compat/compat.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index fdf52af0ee5..65718014f30 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -31,7 +31,7 @@ from tensorflow.python.util.tf_export import tf_export
 # This value changes every day with an automatic CL. It can be modified in code
 # via `forward_compatibility_horizon()` or with the environment variable
 # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2019, 11, 29)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2019, 11, 30)
 _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
 _FORWARD_COMPATIBILITY_DATE_NUMBER = None
 

From d59e6c0bfc2b8eb0cba2d835f4ba00b505fc73b1 Mon Sep 17 00:00:00 2001
From: Jiri Simsa 
Date: Sat, 30 Nov 2019 07:43:54 -0800
Subject: [PATCH 118/279] [tf.data] Migrating static optimization tests to use
 TF combinations.

PiperOrigin-RevId: 283145321
Change-Id: Ic12159919f9e77624986ac5ed3753276f5179ce2
---
 .../kernel_tests/auto_shard_dataset_test.py   |   2 +-
 .../kernel_tests/optimization/BUILD           |  18 +
 .../choose_fastest_branch_dataset_test.py     |  11 +-
 .../choose_fastest_dataset_test.py            |  30 +-
 .../optimization/filter_fusion_test.py        |  55 +-
 .../filter_with_random_uniform_fusion_test.py |   9 +-
 .../optimization/hoist_random_uniform_test.py |  60 ++-
 .../optimization/inject_prefetch_test.py      |  24 +-
 .../optimization/latency_all_edges_test.py    |  11 +-
 .../optimization/map_and_batch_fusion_test.py |   8 +-
 .../map_and_filter_fusion_test.py             |  81 +--
 .../optimization/map_fusion_test.py           |  67 +--
 .../optimization/map_parallelization_test.py  |  50 +-
 .../optimization/map_vectorization_test.py    | 477 +++++++++++-------
 .../optimization/noop_elimination_test.py     |   8 +-
 .../shuffle_and_repeat_fusion_test.py         |   9 +-
 16 files changed, 521 insertions(+), 399 deletions(-)

diff --git a/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py
index 73e68ebcf42..5f13bdae849 100644
--- a/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py
@@ -23,8 +23,8 @@ from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_
 from tensorflow.python.data.experimental.ops import distribute
 from tensorflow.python.data.experimental.ops import distribute_options
 from tensorflow.python.data.experimental.ops import interleave_ops
-from tensorflow.python.data.experimental.ops import testing
 from tensorflow.python.data.experimental.ops import readers
+from tensorflow.python.data.experimental.ops import testing
 from tensorflow.python.data.experimental.ops import unique
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD
index 1e9d1ca1d00..4cd2a3d1fcd 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD
@@ -100,6 +100,24 @@ tf_py_test(
     ],
 )
 
+tf_py_test(
+    name = "inject_prefetch_test",
+    size = "small",
+    srcs = ["inject_prefetch_test.py"],
+    additional_deps = [
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:errors",
+        "//tensorflow/python/data/experimental/ops:testing",
+        "//tensorflow/python/data/kernel_tests:test_base",
+        "//tensorflow/python/data/ops:dataset_ops",
+    ],
+    tags = [
+        "no_oss",
+        "no_pip",
+        "no_windows",
+    ],
+)
+
 tf_py_test(
     name = "latency_all_edges_test",
     size = "small",
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/choose_fastest_branch_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/choose_fastest_branch_dataset_test.py
index ee05ae0603d..bb7849fb213 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/choose_fastest_branch_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/choose_fastest_branch_dataset_test.py
@@ -23,18 +23,18 @@ from tensorflow.python.data.experimental.ops import optimization
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import context
+from tensorflow.python.framework import combinations
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
-from tensorflow.python.framework import test_util
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 
 
-@test_util.run_all_in_graph_and_eager_modes
 class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
                                      parameterized.TestCase):
 
+  @combinations.generate(test_base.default_test_combinations())
   def testSimple(self):
     dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3, 4])
 
@@ -49,6 +49,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
         expected_output=[0, 1, 2, 3, 4],
         expected_shapes=dataset_ops.get_legacy_output_shapes(dataset))
 
+  @combinations.generate(test_base.default_test_combinations())
   def testCaptureSimple(self):
     dataset = dataset_ops.Dataset.range(10)
 
@@ -67,6 +68,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
     self.assertDatasetProduces(
         choose_fastest, expected_output=list(range(1, 11)))
 
+  @combinations.generate(test_base.default_test_combinations())
   def testDifferentFunctions(self):
     dataset = dataset_ops.Dataset.range(100)
 
@@ -83,6 +85,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
         choose_fastest,
         expected_output=[list(range(10 * x, 10 * x + 10)) for x in range(10)])
 
+  @combinations.generate(test_base.default_test_combinations())
   def testWithRepeatBeforeAndAfter(self):
     dataset = dataset_ops.Dataset.from_tensors(0).repeat(10)
 
@@ -99,6 +102,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
     self.assertDatasetProduces(
         choose_fastest, expected_output=[[0] * 10 for _ in range(10)])
 
+  @combinations.generate(test_base.default_test_combinations())
   def testWithPrefetch(self):
     """Should maintain ordering even if the branches do prefetching."""
     dataset = dataset_ops.Dataset.range(100)
@@ -114,6 +118,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
 
     self.assertDatasetProduces(choose_fastest, expected_output=list(range(100)))
 
+  @combinations.generate(test_base.default_test_combinations())
   def testWithMoreOutputThanInput(self):
 
     dataset = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100)
@@ -128,6 +133,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
 
     self.assertDatasetProduces(choose_fastest, expected_output=[0] * 1000)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testWithBadNumElements(self):
 
     dataset = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100)
@@ -153,6 +159,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
           choose_fastest,
           expected_error=(errors.InvalidArgumentError, expected_error_msg))
 
+  @combinations.generate(test_base.default_test_combinations())
   def testErrorWithRepeat(self):
     dataset = dataset_ops.Dataset.from_tensors(0)
 
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/choose_fastest_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/choose_fastest_dataset_test.py
index 3e51de9f1ee..6e0d9842c48 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/choose_fastest_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/choose_fastest_dataset_test.py
@@ -23,15 +23,15 @@ from tensorflow.python.data.experimental.ops import optimization
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import context
+from tensorflow.python.framework import combinations
 from tensorflow.python.framework import errors
-from tensorflow.python.framework import test_util
 from tensorflow.python.platform import test
 
 
-@test_util.run_all_in_graph_and_eager_modes
 class ChooseFastestDatasetTest(test_base.DatasetTestBase,
                                parameterized.TestCase):
 
+  @combinations.generate(test_base.default_test_combinations())
   def testChooseFastestSimple(self):
     dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3, 4])
     merge = optimization._ChooseFastestDataset([dataset, dataset])
@@ -40,6 +40,7 @@ class ChooseFastestDatasetTest(test_base.DatasetTestBase,
         expected_output=[0, 1, 2, 3, 4],
         expected_shapes=dataset_ops.get_legacy_output_shapes(dataset))
 
+  @combinations.generate(test_base.default_test_combinations())
   def testChooseFastestManyInputs(self):
     dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3, 4])
     merge = optimization._ChooseFastestDataset([dataset for _ in range(5)])
@@ -48,6 +49,7 @@ class ChooseFastestDatasetTest(test_base.DatasetTestBase,
         expected_output=[0, 1, 2, 3, 4],
         expected_shapes=dataset_ops.get_legacy_output_shapes(dataset))
 
+  @combinations.generate(test_base.default_test_combinations())
   def testChooseFastest(self):
     dataset = dataset_ops.Dataset.range(600)
     f = lambda x: 2 * x
@@ -61,11 +63,25 @@ class ChooseFastestDatasetTest(test_base.DatasetTestBase,
         ],
         expected_shapes=dataset_ops.get_legacy_output_shapes(dataset_a))
 
-  @parameterized.named_parameters(
-      ("Shapes", [0], [[1, 2, 3]], "must have compatible output shapes."),
-      ("Types", [0], [0.0], "must have the same output types."),
-      ("NumComponents", [0], ([0], [1]), "must have the same output types."),
-      ("Cardinality", [1, 2, 3], [1], "must have compatible cardinalities."))
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(
+              slices_a=[[0]],
+              slices_b=[[[1, 2, 3]]],
+              error_msg="must have compatible output shapes.") +
+          combinations.combine(
+              slices_a=[[0]],
+              slices_b=[[0.0]],
+              error_msg="must have the same output types.") +
+          combinations.combine(
+              slices_a=[[0]],
+              slices_b=[([0], [1])],
+              error_msg="must have the same output types.") +
+          combinations.combine(
+              slices_a=[[1, 2, 3]],
+              slices_b=[[0]],
+              error_msg="must have compatible cardinalities.")))
   def testChooseFastestErrorWithIncompatibleInput(self, slices_a, slices_b,
                                                   error_msg):
     dataset_a = dataset_ops.Dataset.from_tensor_slices(slices_a)
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/filter_fusion_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/filter_fusion_test.py
index 1aa3d636f02..949f9e2e25c 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/filter_fusion_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/filter_fusion_test.py
@@ -22,47 +22,16 @@ from absl.testing import parameterized
 from tensorflow.python.data.experimental.ops import testing
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import combinations
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import test_util
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 
 
-def _filter_fusion_test_cases():
-  """Generates test cases for the FilterFusion optimization."""
-
-  take_all = lambda x: constant_op.constant(True)
-  is_zero = lambda x: math_ops.equal(x, 0)
-  greater = lambda x: math_ops.greater(x + 5, 0)
-
-  tests = []
-  filters = [take_all, is_zero, greater]
-  identity = lambda x: x
-  for x, predicate_1 in enumerate(filters):
-    for y, predicate_2 in enumerate(filters):
-      tests.append(("Mixed{}{}".format(x, y), identity,
-                    [predicate_1, predicate_2]))
-      for z, predicate_3 in enumerate(filters):
-        tests.append(("Mixed{}{}{}".format(x, y, z), identity,
-                      [predicate_1, predicate_2, predicate_3]))
-
-  take_all_multiple = lambda x, y: constant_op.constant(True)
-  # Multi output
-  tests.append(("Multi1", lambda x: (x, x),
-                [take_all_multiple, take_all_multiple]))
-  tests.append(("Multi2", lambda x: (x, 2), [
-      take_all_multiple,
-      lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
-  ]))
-  return tuple(tests)
-
-
-@test_util.run_all_in_graph_and_eager_modes
 class FilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
 
-  @parameterized.named_parameters(*_filter_fusion_test_cases())
-  def testFilterFusion(self, map_function, predicates):
+  def _testFilterFusion(self, map_function, predicates):
     dataset = dataset_ops.Dataset.range(5).apply(
         testing.assert_next(["Map", "Filter",
                              "MemoryCacheImpl"])).map(map_function)
@@ -91,6 +60,26 @@ class FilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
         expected_output.append(r)
     self.assertDatasetProduces(dataset, expected_output=expected_output)
 
+  @combinations.generate(test_base.default_test_combinations())
+  def testFilterFusionScalar(self):
+    take_all = lambda x: constant_op.constant(True)
+    is_zero = lambda x: math_ops.equal(x, 0)
+    greater = lambda x: math_ops.greater(x + 5, 0)
+    predicates = [take_all, is_zero, greater]
+    for x in predicates:
+      for y in predicates:
+        self._testFilterFusion(lambda x: x, [x, y])
+        for z in predicates:
+          self._testFilterFusion(lambda x: x, [x, y, z])
+
+  @combinations.generate(test_base.default_test_combinations())
+  def testFilterFusionTuple(self):
+    take_all = lambda x, y: constant_op.constant(True)
+    is_zero = lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
+
+    self._testFilterFusion(lambda x: (x, x), [take_all, take_all])
+    self._testFilterFusion(lambda x: (x, 2), [take_all, is_zero])
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/filter_with_random_uniform_fusion_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/filter_with_random_uniform_fusion_test.py
index 2b130f40fc9..76006252367 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/filter_with_random_uniform_fusion_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/filter_with_random_uniform_fusion_test.py
@@ -17,17 +17,20 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
+
 from tensorflow.python.data.experimental.ops import testing
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import test_util
+from tensorflow.python.framework import combinations
 from tensorflow.python.ops import random_ops
 from tensorflow.python.platform import test
 
 
-@test_util.run_all_in_graph_and_eager_modes
-class FilterWithRandomUniformFusionTest(test_base.DatasetTestBase):
+class FilterWithRandomUniformFusionTest(test_base.DatasetTestBase,
+                                        parameterized.TestCase):
 
+  @combinations.generate(test_base.default_test_combinations())
   def testFilterWithRandomUniformFusion(self):
     dataset = dataset_ops.Dataset.range(10000000).apply(
         testing.assert_next(["Sampling"]))
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py
index 928b435fe5c..59f50fa1752 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py
@@ -22,44 +22,17 @@ from absl.testing import parameterized
 from tensorflow.python.data.experimental.ops import testing
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import combinations
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import test_util
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.platform import test
 
 
-def _hoist_random_uniform_test_cases():
-  """Generates test cases for the HoistRandomUniform optimization."""
-
-  plus_one = lambda x: x + 1
-
-  def random(_):
-    return random_ops.random_uniform([],
-                                     minval=1,
-                                     maxval=10,
-                                     dtype=dtypes.float32,
-                                     seed=42)
-
-  def random_with_assert(x):
-    y = random(x)
-    assert_op = control_flow_ops.Assert(math_ops.greater_equal(y, 1), [y])
-    with ops.control_dependencies([assert_op]):
-      return y
-
-  twice_random = lambda x: (random(x) + random(x)) / 2.
-
-  tests = [("PlusOne", plus_one, False), ("RandomUniform", random, True),
-           ("RandomWithAssert", random_with_assert, True),
-           ("TwiceRandom", twice_random, False)]
-  return tuple(tests)
-
-
-@test_util.run_all_in_graph_and_eager_modes
 class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
 
   def _testDataset(self, dataset):
@@ -78,11 +51,10 @@ class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(get_next())
 
-  @parameterized.named_parameters(*_hoist_random_uniform_test_cases())
-  def testHoisting(self, function, will_optimize):
+  def _testHoistFunction(self, function, should_optimize):
     dataset = dataset_ops.Dataset.range(5).apply(
         testing.assert_next(
-            ["Zip[0]", "Map"] if will_optimize else ["Map"])).map(function)
+            ["Zip[0]", "Map"] if should_optimize else ["Map"])).map(function)
 
     options = dataset_ops.Options()
     options.experimental_optimization.apply_default_optimizations = False
@@ -90,6 +62,32 @@ class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
     dataset = dataset.with_options(options)
     self._testDataset(dataset)
 
+  @combinations.generate(test_base.default_test_combinations())
+  def testNoRandom(self):
+    self._testHoistFunction(lambda x: x + 1, should_optimize=False)
+
+  @combinations.generate(test_base.default_test_combinations())
+  def testRandom(self):
+
+    def random(_):
+      return random_ops.random_uniform([],
+                                       minval=1,
+                                       maxval=10,
+                                       dtype=dtypes.float32,
+                                       seed=42)
+
+    def random_with_assert(x):
+      y = random(x)
+      assert_op = control_flow_ops.Assert(math_ops.greater_equal(y, 1), [y])
+      with ops.control_dependencies([assert_op]):
+        return y
+
+    self._testHoistFunction(random, should_optimize=True)
+    self._testHoistFunction(random_with_assert, should_optimize=True)
+    self._testHoistFunction(
+        lambda x: (random(x) + random(x)) / 2, should_optimize=False)
+
+  @combinations.generate(test_base.default_test_combinations())
   def testCapturedInputs(self):
     a = constant_op.constant(1, dtype=dtypes.float32)
     b = constant_op.constant(0, dtype=dtypes.float32)
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/inject_prefetch_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/inject_prefetch_test.py
index 89f61f141b0..d1a45d7328e 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/inject_prefetch_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/inject_prefetch_test.py
@@ -17,35 +17,38 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.data.experimental.ops import optimization
+from absl.testing import parameterized
+
+from tensorflow.python.data.experimental.ops import testing
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import test_util
+from tensorflow.python.framework import combinations
 from tensorflow.python.platform import test
 
 
-@test_util.run_all_in_graph_and_eager_modes
-class InjectPrefetchTest(test_base.DatasetTestBase):
+class InjectPrefetchTest(test_base.DatasetTestBase, parameterized.TestCase):
 
   def _enable_autotune_buffers(self, dataset):
     options = dataset_ops.Options()
     options.experimental_optimization.autotune_buffers = True
     return dataset.with_options(options)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testParallelMap(self):
     dataset = dataset_ops.Dataset.range(100)
     dataset = dataset.apply(
-        optimization.assert_next(["ParallelMap", "Prefetch", "FiniteTake"]))
+        testing.assert_next(["ParallelMap", "Prefetch", "FiniteTake"]))
     dataset = dataset.map(
         lambda x: x + 1, num_parallel_calls=dataset_ops.AUTOTUNE)
     dataset = dataset.take(50)
     dataset = self._enable_autotune_buffers(dataset)
     self.assertDatasetProduces(dataset, range(1, 51))
 
+  @combinations.generate(test_base.default_test_combinations())
   def testMapAndBatch(self):
     dataset = dataset_ops.Dataset.range(100)
     dataset = dataset.apply(
-        optimization.assert_next(["MapAndBatch", "Prefetch", "FiniteTake"]))
+        testing.assert_next(["MapAndBatch", "Prefetch", "FiniteTake"]))
     dataset = dataset.map(
         lambda x: x + 1, num_parallel_calls=dataset_ops.AUTOTUNE)
     dataset = dataset.batch(10)
@@ -54,10 +57,11 @@ class InjectPrefetchTest(test_base.DatasetTestBase):
     self.assertDatasetProduces(
         dataset, [list(range(i + 1, i + 11)) for i in range(0, 50, 10)])
 
+  @combinations.generate(test_base.default_test_combinations())
   def testParallelInterleaveV2(self):
     dataset = dataset_ops.Dataset.range(100)
     dataset = dataset.apply(
-        optimization.assert_next(
+        testing.assert_next(
             ["ParallelInterleaveV2", "Prefetch", "FiniteTake"]))
     dataset = dataset.interleave(
         lambda x: dataset_ops.Dataset.from_tensors(x + 1),
@@ -66,10 +70,11 @@ class InjectPrefetchTest(test_base.DatasetTestBase):
     dataset = self._enable_autotune_buffers(dataset)
     self.assertDatasetProduces(dataset, range(1, 51))
 
+  @combinations.generate(test_base.default_test_combinations())
   def testChainedParallelDatasets(self):
     dataset = dataset_ops.Dataset.range(100)
     dataset = dataset.apply(
-        optimization.assert_next([
+        testing.assert_next([
             "ParallelMap", "Prefetch", "ParallelInterleaveV2", "Prefetch",
             "MapAndBatch", "Prefetch", "FiniteTake"
         ]))
@@ -85,9 +90,10 @@ class InjectPrefetchTest(test_base.DatasetTestBase):
     dataset = self._enable_autotune_buffers(dataset)
     self.assertDatasetProduces(dataset, [[i] for i in range(3, 53)])
 
+  @combinations.generate(test_base.default_test_combinations())
   def testNoRegularMap(self):
     dataset = dataset_ops.Dataset.range(100)
-    dataset = dataset.apply(optimization.assert_next(["Map", "FiniteTake"]))
+    dataset = dataset.apply(testing.assert_next(["Map", "FiniteTake"]))
     dataset = dataset.map(lambda x: x + 1).take(50)
     dataset = self._enable_autotune_buffers(dataset)
     self.assertDatasetProduces(dataset, range(1, 51))
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/latency_all_edges_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/latency_all_edges_test.py
index f6e5111cf32..d9ebc1cc719 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/latency_all_edges_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/latency_all_edges_test.py
@@ -17,15 +17,22 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
+
 from tensorflow.python.data.experimental.kernel_tests import stats_dataset_test_base
-from tensorflow.python.data.experimental.ops import testing
 from tensorflow.python.data.experimental.ops import stats_aggregator
+from tensorflow.python.data.experimental.ops import testing
+from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import combinations
 from tensorflow.python.platform import test
 
 
-class LatencyAllEdgesTest(stats_dataset_test_base.StatsDatasetTestBase):
+class LatencyAllEdgesTest(stats_dataset_test_base.StatsDatasetTestBase,
+                          parameterized.TestCase):
 
+  # TODO(jsimsa): Investigate why are graph-mode tests failing.
+  @combinations.generate(test_base.eager_only_combinations())
   def testLatencyStatsOptimization(self):
     aggregator = stats_aggregator.StatsAggregator()
     dataset = dataset_ops.Dataset.from_tensors(1).apply(
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_batch_fusion_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_batch_fusion_test.py
index c7e6fbbf377..622b6ca5671 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_batch_fusion_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_batch_fusion_test.py
@@ -17,16 +17,18 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
+
 from tensorflow.python.data.experimental.ops import testing
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import test_util
+from tensorflow.python.framework import combinations
 from tensorflow.python.platform import test
 
 
-@test_util.run_all_in_graph_and_eager_modes
-class MapAndBatchFusionTest(test_base.DatasetTestBase):
+class MapAndBatchFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
 
+  @combinations.generate(test_base.default_test_combinations())
   def testMapAndBatchFusion(self):
     dataset = dataset_ops.Dataset.range(10).apply(
         testing.assert_next(
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py
index 1e53b4394ae..a0257f76e93 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py
@@ -22,50 +22,16 @@ from absl.testing import parameterized
 from tensorflow.python.data.experimental.ops import testing
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import combinations
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import test_util
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 
 
-def _map_and_filter_fusion_test_cases():
-  """Generates test cases for the MapAndFilterFusion optimization."""
-
-  identity = lambda x: x
-  increment = lambda x: x + 1
-  minus_five = lambda x: x - 5
-
-  def increment_and_square(x):
-    y = x + 1
-    return y * y
-
-  take_all = lambda x: constant_op.constant(True)
-  is_zero = lambda x: math_ops.equal(x, 0)
-  is_odd = lambda x: math_ops.equal(x % 2, 0)
-  greater = lambda x: math_ops.greater(x + 5, 0)
-
-  functions = [identity, increment, minus_five, increment_and_square]
-  filters = [take_all, is_zero, is_odd, greater]
-  tests = []
-
-  for x, fun in enumerate(functions):
-    for y, predicate in enumerate(filters):
-      tests.append(("Mixed{}{}".format(x, y), fun, predicate))
-
-  # Multi output
-  tests.append(("Multi1", lambda x: (x, x),
-                lambda x, y: constant_op.constant(True)))
-  tests.append(
-      ("Multi2", lambda x: (x, 2),
-       lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)))
-  return tuple(tests)
-
-
-@test_util.run_all_in_graph_and_eager_modes
 class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
 
-  def _testMapAndFilter(self, dataset, function, predicate):
+  def _testDataset(self, dataset, function, predicate):
     expected_output = []
     for x in range(10):
       r = function(x)
@@ -77,8 +43,7 @@ class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
         expected_output.append(r)
     self.assertDatasetProduces(dataset, expected_output=expected_output)
 
-  @parameterized.named_parameters(*_map_and_filter_fusion_test_cases())
-  def testMapFilterFusion(self, function, predicate):
+  def _testMapAndFilterFusion(self, function, predicate):
     dataset = dataset_ops.Dataset.range(10).apply(
         testing.assert_next(["Map", "Filter",
                              "Map"])).map(function).filter(predicate)
@@ -86,8 +51,44 @@ class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
     options.experimental_optimization.apply_default_optimizations = False
     options.experimental_optimization.map_and_filter_fusion = True
     dataset = dataset.with_options(options)
-    self._testMapAndFilter(dataset, function, predicate)
+    self._testDataset(dataset, function, predicate)
 
+  @combinations.generate(test_base.default_test_combinations())
+  def testMapAndFilterFusionScalar(self):
+    identity = lambda x: x
+    increment = lambda x: x + 1
+    minus_five = lambda x: x - 5
+
+    def increment_and_square(x):
+      y = x + 1
+      return y * y
+
+    functions = [identity, increment, minus_five, increment_and_square]
+
+    take_all = lambda x: constant_op.constant(True)
+    is_zero = lambda x: math_ops.equal(x, 0)
+    is_odd = lambda x: math_ops.equal(x % 2, 0)
+    greater = lambda x: math_ops.greater(x + 5, 0)
+    predicates = [take_all, is_zero, is_odd, greater]
+
+    for function in functions:
+      for predicate in predicates:
+        self._testMapAndFilterFusion(function, predicate)
+
+  @combinations.generate(test_base.default_test_combinations())
+  def testMapAndFilterFusionTuple(self):
+    replicate = lambda x: (x, x)
+    with_two = lambda x: (x, 2)
+    functions = [replicate, with_two]
+    take_all = lambda x, y: constant_op.constant(True)
+    is_zero = lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
+    predicates = [take_all, is_zero]
+
+    for function in functions:
+      for predicate in predicates:
+        self._testMapAndFilterFusion(function, predicate)
+
+  @combinations.generate(test_base.default_test_combinations())
   def testCapturedInputs(self):
     a = constant_op.constant(3, dtype=dtypes.int64)
     b = constant_op.constant(4, dtype=dtypes.int64)
@@ -104,7 +105,7 @@ class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
     options.experimental_optimization.apply_default_optimizations = False
     options.experimental_optimization.map_and_filter_fusion = True
     dataset = dataset.with_options(options)
-    self._testMapAndFilter(dataset, function, predicate)
+    self._testDataset(dataset, function, predicate)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_fusion_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_fusion_test.py
index 10f27dc277f..28da0474bc9 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_fusion_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_fusion_test.py
@@ -22,51 +22,13 @@ from absl.testing import parameterized
 from tensorflow.python.data.experimental.ops import testing
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import test_util
+from tensorflow.python.framework import combinations
 from tensorflow.python.platform import test
 
 
-def _map_fusion_test_cases():
-  """Generates test cases for the MapFusion optimization."""
-
-  identity = lambda x: x
-  increment = lambda x: x + 1
-
-  def increment_and_square(x):
-    y = x + 1
-    return y * y
-
-  functions = [identity, increment, increment_and_square]
-  tests = []
-  for i, fun1 in enumerate(functions):
-    for j, fun2 in enumerate(functions):
-      tests.append((
-          "Test{}{}".format(i, j),
-          [fun1, fun2],
-      ))
-      for k, fun3 in enumerate(functions):
-        tests.append((
-            "Test{}{}{}".format(i, j, k),
-            [fun1, fun2, fun3],
-        ))
-
-  swap = lambda x, n: (n, x)
-  tests.append((
-      "Swap1",
-      [lambda x: (x, 42), swap],
-  ))
-  tests.append((
-      "Swap2",
-      [lambda x: (x, 42), swap, swap],
-  ))
-  return tuple(tests)
-
-
-@test_util.run_all_in_graph_and_eager_modes
 class MapFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
 
-  @parameterized.named_parameters(*_map_fusion_test_cases())
-  def testMapFusion(self, functions):
+  def _testMapFusion(self, functions):
     dataset = dataset_ops.Dataset.range(5).apply(
         testing.assert_next(["Map", "MemoryCacheImpl"]))
     for function in functions:
@@ -88,6 +50,31 @@ class MapFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
       expected_output.append(r)
     self.assertDatasetProduces(dataset, expected_output=expected_output)
 
+  @combinations.generate(test_base.default_test_combinations())
+  def testMapFusionScalar(self):
+    identity = lambda x: x
+    increment = lambda x: x + 1
+
+    def increment_and_square(x):
+      y = x + 1
+      return y * y
+
+    functions = [identity, increment, increment_and_square]
+
+    for x in functions:
+      for y in functions:
+        self._testMapFusion([x, y])
+        for z in functions:
+          self._testMapFusion([x, y, z])
+
+  @combinations.generate(test_base.default_test_combinations())
+  def testMapAndFilterFusionTuple(self):
+    with_42 = lambda x: (x, 42)
+    swap = lambda x, y: (y, x)
+
+    self._testMapFusion([with_42, swap])
+    self._testMapFusion([with_42, swap, swap])
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py
index 668ab28c64c..a28a3052abc 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py
@@ -22,38 +22,20 @@ from absl.testing import parameterized
 from tensorflow.python.data.experimental.ops import testing
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import combinations
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import test_util
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 
 
-def _map_parallelization_test_cases():
-  """Generates test cases for the MapParallelization optimization."""
-
-  identity = lambda x: x
-  increment = lambda x: x + 1
-
-  def assert_greater(x):
-    assert_op = control_flow_ops.Assert(math_ops.greater(x, -1), [x])
-    with ops.control_dependencies([assert_op]):
-      return x
-
-  return (("Identity", identity, True),
-          ("Increment", increment, True),
-          ("AssertGreater", assert_greater, True))
-
-
-@test_util.run_all_in_graph_and_eager_modes
 class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase):
 
-  @parameterized.named_parameters(*_map_parallelization_test_cases())
-  def testMapParallelization(self, function, should_be_parallel):
-    next_nodes = ["ParallelMap"] if should_be_parallel else ["Map"]
+  def _testMapParallelization(self, function, should_optimize):
+    next_nodes = ["ParallelMap"] if should_optimize else ["Map"]
     dataset = dataset_ops.Dataset.range(5).apply(
         testing.assert_next(next_nodes)).map(function)
     options = dataset_ops.Options()
@@ -63,9 +45,26 @@ class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertDatasetProduces(
         dataset, expected_output=[function(x) for x in range(5)])
 
-  def testMapParallelizationWithCapturedConstant(self):
-    """Tests that functions with captured constants are parallelized."""
+  @combinations.generate(test_base.default_test_combinations())
+  def testIdentity(self):
+    self._testMapParallelization(lambda x: x, should_optimize=True)
 
+  @combinations.generate(test_base.default_test_combinations())
+  def testIncrement(self):
+    self._testMapParallelization(lambda x: x + 1, should_optimize=True)
+
+  @combinations.generate(test_base.default_test_combinations())
+  def testAssert(self):
+
+    def assert_greater(x):
+      assert_op = control_flow_ops.Assert(math_ops.greater(x, -1), [x])
+      with ops.control_dependencies([assert_op]):
+        return x
+
+    self._testMapParallelization(assert_greater, should_optimize=True)
+
+  @combinations.generate(test_base.default_test_combinations())
+  def testCapturedConstant(self):
     captured_t = constant_op.constant(42, dtype=dtypes.int64)
     def fn(x):
       return x + captured_t
@@ -78,9 +77,8 @@ class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertDatasetProduces(
         dataset, expected_output=[x + 42 for x in range(5)])
 
-  def testMapParallelizationWithCapturedVariable(self):
-    """Tests that functions with captured variables are not parallelized."""
-
+  @combinations.generate(test_base.default_test_combinations())
+  def testCapturedVariable(self):
     captured_t = variables.Variable(42, dtype=dtypes.int64)
     def fn(x):
       return x + captured_t
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
index f17d863e555..4569f171f75 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
@@ -17,6 +17,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import functools
+
 from absl.testing import parameterized
 import numpy as np
 
@@ -26,12 +28,12 @@ from tensorflow.python.data.experimental.ops import batching
 from tensorflow.python.data.experimental.ops import testing
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import combinations
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import bitwise_ops
 from tensorflow.python.ops import check_ops
@@ -43,21 +45,45 @@ from tensorflow.python.ops import parsing_ops
 from tensorflow.python.platform import test
 
 
-def _generate_unary_cwise_math_cases():
-  # TODO(rachelim): Consolidate tests with pfor when APIs are somewhat shared.
-  bitwise_cases = [("Invert", bitwise_ops.invert)]
-  logical_cases = [("LogicalNot", math_ops.logical_not)]
-  complex_cases = [
+def _generate_test_combinations(cases):
+
+  def reduce_fn(x, y):
+    name, fn = y
+    return x + combinations.combine(map_fn=combinations.NamedObject(name, fn))
+
+  return functools.reduce(reduce_fn, cases, [])
+
+
+def _unary_bitwise_test_combinations():
+  cases = [("Invert", bitwise_ops.invert)]
+  return _generate_test_combinations(cases)
+
+
+def _unary_logical_test_combinations():
+  cases = [("LogicalNot", math_ops.logical_not)]
+  return _generate_test_combinations(cases)
+
+
+def _unary_complex_test_combinations():
+  cases = [
       ("Angle", math_ops.angle),
       ("ComplexAbs", math_ops.abs),
       ("Conj", math_ops.conj),
       ("Imag", math_ops.imag),
       ("Real", math_ops.real),
   ]
-  real_cases = [
+  return _generate_test_combinations(cases)
+
+
+def _unary_real_test_combinations():
+  # acosh requires values x >= 1
+  def safe_acosh(x):
+    return math_ops.acosh(1 + math_ops.square(x))
+
+  cases = [
       ("Abs", math_ops.abs),
       ("Acos", math_ops.acos),
-      ("Acosh", lambda x: math_ops.acosh(1 + math_ops.square(x))),
+      ("Acosh", safe_acosh),
       ("Asin", math_ops.asin),
       ("Asinh", math_ops.asinh),
       ("Atan", math_ops.atan),
@@ -99,45 +125,26 @@ def _generate_unary_cwise_math_cases():
       ("Tan", math_ops.tan),
       ("Tanh", math_ops.tanh),
   ]
-  random_input = np.random.rand(3, 5)
-  complex_component = np.random.rand(3, 5)
-  random_int = np.random.randint(0, 10, (7, 3, 5))
-
-  def bitwise_dataset_factory():
-    return dataset_ops.Dataset.from_tensor_slices(random_int)
-
-  def logical_dataset_factory():
-    return dataset_ops.Dataset.from_tensor_slices(random_input > 0)
-
-  def random_dataset_factory():
-    return dataset_ops.Dataset.from_tensor_slices(random_input)
-
-  def complex_dataset_factory():
-    return dataset_ops.Dataset.from_tensor_slices(
-        math_ops.complex(random_input, complex_component))
-
-  case_factory_pairs = [
-      (bitwise_cases, bitwise_dataset_factory),
-      (logical_cases, logical_dataset_factory),
-      (complex_cases, complex_dataset_factory),
-      (real_cases, random_dataset_factory),
-  ]
-  return [(case[0], case[1], factory)
-          for cases, factory in case_factory_pairs
-          for case in cases]
+  return _generate_test_combinations(cases)
 
 
-def _generate_binary_cwise_math_cases():
-  bitwise_cases = [("BitwiseAnd", bitwise_ops.bitwise_and),
-                   ("BitwiseOr", bitwise_ops.bitwise_or),
-                   ("BitwiseXor", bitwise_ops.bitwise_xor),
-                   ("LeftShift", bitwise_ops.left_shift),
-                   ("RightShift", bitwise_ops.right_shift)]
+def _binary_bitwise_test_combinations():
+  cases = [("BitwiseAnd", bitwise_ops.bitwise_and),
+           ("BitwiseOr", bitwise_ops.bitwise_or),
+           ("BitwiseXor", bitwise_ops.bitwise_xor),
+           ("LeftShift", bitwise_ops.left_shift),
+           ("RightShift", bitwise_ops.right_shift)]
+  return _generate_test_combinations(cases)
 
-  logical_cases = [("LogicalAnd", math_ops.logical_and),
-                   ("LogicalOr", math_ops.logical_or)]
 
-  # Wrapper functions restricting the range of inputs of zeta and polygamma.
+def _binary_logical_test_combinations():
+  cases = [("LogicalAnd", math_ops.logical_and),
+           ("LogicalOr", math_ops.logical_or)]
+  return _generate_test_combinations(cases)
+
+
+def _binary_real_test_combinations():
+
   def safe_polygamma(x, y):
     return math_ops.polygamma(
         math_ops.round(clip_ops.clip_by_value(y, 1, 10)), x * x + 1)
@@ -145,7 +152,7 @@ def _generate_binary_cwise_math_cases():
   def safe_zeta(x, y):
     return math_ops.zeta(x * x + 1, y * y)
 
-  real_cases = [
+  cases = [
       ("Add", math_ops.add),
       ("AddV2", math_ops.add_v2),
       ("Atan2", math_ops.atan2),
@@ -174,150 +181,10 @@ def _generate_binary_cwise_math_cases():
       ("TruncateMod", math_ops.truncate_mod),
       ("Zeta", safe_zeta),
   ]
-
-  # Exercises broadcasting capabilities
-  x = np.random.rand(7, 3, 5)
-  y = np.random.rand(3, 5)
-
-  x_int = np.random.randint(0, 10, (7, 3, 5))
-  y_int = np.random.randint(0, 10, (3, 5))
-
-  def bitwise_dataset_factory():
-    return dataset_ops.Dataset.from_tensors((x_int, y_int))
-
-  def logical_dataset_factory():
-    return dataset_ops.Dataset.from_tensors((x > 0, y > 0))
-
-  def random_dataset_factory():
-    return dataset_ops.Dataset.from_tensors((x, y))
-
-  case_factory_pairs = [
-      (bitwise_cases, bitwise_dataset_factory),
-      (logical_cases, logical_dataset_factory),
-      (real_cases, random_dataset_factory),
-  ]
-  return [(case[0], case[1], factory)
-          for cases, factory in case_factory_pairs
-          for case in cases]
+  return _generate_test_combinations(cases)
 
 
-def _generate_cwise_test_cases():
-  return _generate_unary_cwise_math_cases() + _generate_binary_cwise_math_cases(
-  )
-
-
-def _generate_csv_test_case():
-
-  def csv_factory():
-    return dataset_ops.Dataset.from_tensor_slices(["1.0:2:a",
-                                                   "2.4:5:c"]).repeat(5)
-
-  def decode_csv_fn(x):
-    return parsing_ops.decode_csv(
-        x,
-        record_defaults=[
-            constant_op.constant([], dtypes.float32),
-            constant_op.constant([], dtypes.int32),
-            constant_op.constant([], dtypes.string)
-        ],
-        field_delim=":")
-
-  return decode_csv_fn, csv_factory
-
-
-def _generate_parse_single_example_test_case():
-  # When sparse tensors are used, map_vectorization is not
-  # attempted because the output_shapes of the map dataset are not defined.
-  # TODO(rachelim): Consider being more lax with checking the output_shapes of
-  # the map node.
-
-  def parse_example_factory():
-
-    def _int64_feature(*values):
-      return feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=values))
-
-    def _bytes_feature(*values):
-      return feature_pb2.Feature(
-          bytes_list=feature_pb2.BytesList(
-              value=[v.encode("utf-8") for v in values]))
-
-    return dataset_ops.Dataset.from_tensor_slices(
-        constant_op.constant([
-            example_pb2.Example(
-                features=feature_pb2.Features(
-                    feature={
-                        "dense_int": _int64_feature(i),
-                        "dense_str": _bytes_feature(str(i)),
-                    })).SerializeToString() for i in range(10)
-        ]))
-
-  def parse_single_example_fn(x):
-    features = {
-        "dense_int": parsing_ops.FixedLenFeature((), dtypes.int64, 0),
-        "dense_str": parsing_ops.FixedLenFeature((), dtypes.string, ""),
-    }
-    return parsing_ops.parse_single_example(x, features)
-
-  return parse_single_example_fn, parse_example_factory
-
-
-def _generate_optimization_test_cases():
-
-  def base_dataset_factory():
-    return dataset_ops.Dataset.from_tensors(np.random.rand(10, 3)).repeat(5)
-
-  rand_val = np.random.rand(1, 1, 1, 1, 1, 1)
-
-  csv_test_case = _generate_csv_test_case()
-  parse_fn, parse_base = _generate_parse_single_example_test_case()
-
-  def dense_output_only_parse_fn(x):
-    # Since we haven't implemented a vectorizer for SerializeSparse, any
-    # function with sparse outputs will only be naively vectorized.
-    parse_result = parse_fn(x)
-    return [
-        y for y in parse_result if not isinstance(y, sparse_tensor.SparseTensor)
-    ]
-
-  def map_fn_with_cycle(x):
-    c = lambda i: math_ops.less(i, 10)
-    b = lambda i: math_ops.add(i, 1)
-    return control_flow_ops.while_loop(c, b, [x])
-
-  # Misc test cases
-  test_cases = [
-      ("Basic", lambda x: (x, x + 1), base_dataset_factory),
-      ("Broadcast", lambda x: x + rand_val, base_dataset_factory),
-      ("Cycle", map_fn_with_cycle, lambda: dataset_ops.Dataset.from_tensors(1)),
-      ("Const", lambda x: 2, base_dataset_factory),
-      ("Cast", lambda x: math_ops.cast(x, dtypes.float64),
-       base_dataset_factory),
-      ("Reshape", lambda x: array_ops.reshape(x, (-1, 30)),
-       base_dataset_factory),
-      ("Transpose", array_ops.transpose, base_dataset_factory),
-      ("Unpack", array_ops.unstack, base_dataset_factory),
-      ("UnpackNegativeAxis", lambda x: array_ops.unstack(x, axis=-1),
-       base_dataset_factory),
-      # Parsing ops
-      ("DecodeCSV", csv_test_case[0], csv_test_case[1]),
-      ("ParseSingleExample", parse_fn, parse_base),
-      ("ParseSingleExampleDenseOutputOnly", dense_output_only_parse_fn,
-       parse_base),
-  ] + _generate_cwise_test_cases()
-
-  return [{
-      "testcase_name":
-          x[0] + "Parallel" if num_parallel_calls is not None else x[0],
-      "map_fn":
-          x[1],
-      "base_dataset_factory":
-          x[2],
-      "num_parallel_calls":
-          num_parallel_calls
-  } for x in test_cases for num_parallel_calls in (None, 12)]
-
-
-@test_util.run_all_in_graph_and_eager_modes
+# TODO(rachelim): Consolidate tests with pfor when APIs are somewhat shared.
 class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
 
   def _enable_map_vectorization(self, dataset, use_choose=True):
@@ -370,13 +237,223 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
     optimized = self._enable_map_vectorization(optimized)
     return unoptimized, optimized
 
-  @parameterized.named_parameters(_generate_optimization_test_cases())
-  def testOptimization(self, map_fn, base_dataset_factory, num_parallel_calls):
-    base_dataset = base_dataset_factory()
-    unoptimized, optimized = self._get_test_datasets(base_dataset, map_fn,
+  def _testOptimization(self, map_fn, dataset_factory, num_parallel_calls):
+    dataset = dataset_factory()
+    unoptimized, optimized = self._get_test_datasets(dataset, map_fn,
                                                      num_parallel_calls)
     self.assertDatasetsEqual(unoptimized, optimized)
 
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         combinations.combine(num_parallel_calls=[None, 12])))
+  def testBasic(self, num_parallel_calls):
+    data = np.random.rand(10, 3)
+    dataset_factory = lambda: dataset_ops.Dataset.from_tensors(data).repeat(5)
+    map_fn = lambda x: (x, x + 1)
+    self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
+
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         combinations.combine(num_parallel_calls=[None, 12])))
+  def testBroadcast(self, num_parallel_calls):
+    data = np.random.rand(10, 3)
+    dataset_factory = lambda: dataset_ops.Dataset.from_tensors(data).repeat(5)
+    value = np.random.rand(1, 1, 1, 1, 1, 1)
+    map_fn = lambda x: x + value
+    self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
+
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         combinations.combine(num_parallel_calls=[None, 12])))
+  def testCast(self, num_parallel_calls):
+    data = np.random.rand(10, 3)
+    dataset_factory = lambda: dataset_ops.Dataset.from_tensors(data).repeat(5)
+    map_fn = lambda x: math_ops.cast(x, dtypes.float64)
+    self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
+
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         combinations.combine(num_parallel_calls=[None, 12])))
+  def testConst(self, num_parallel_calls):
+    data = np.random.rand(10, 3)
+    dataset_factory = lambda: dataset_ops.Dataset.from_tensors(data).repeat(5)
+    map_fn = lambda x: 2
+    self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
+
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         combinations.combine(num_parallel_calls=[None, 12])))
+  def testCycle(self, num_parallel_calls):
+    dataset_factory = lambda: dataset_ops.Dataset.from_tensors(1)
+
+    def map_fn(x):
+      c = lambda i: math_ops.less(i, 10)
+      b = lambda i: math_ops.add(i, 1)
+      return control_flow_ops.while_loop(c, b, [x])
+
+    self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
+
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         combinations.combine(num_parallel_calls=[None, 12])))
+  def testReshape(self, num_parallel_calls):
+    data = np.random.rand(10, 3)
+    dataset_factory = lambda: dataset_ops.Dataset.from_tensors(data).repeat(5)
+    map_fn = lambda x: array_ops.reshape(x, (-1, 30))
+    self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
+
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         combinations.combine(num_parallel_calls=[None, 12])))
+  def testTranspose(self, num_parallel_calls):
+    data = np.random.rand(10, 3)
+    dataset_factory = lambda: dataset_ops.Dataset.from_tensors(data).repeat(5)
+    map_fn = array_ops.transpose
+    self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
+
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         combinations.combine(num_parallel_calls=[None, 12])))
+  def testUnstack(self, num_parallel_calls):
+    data = np.random.rand(10, 3)
+    dataset_factory = lambda: dataset_ops.Dataset.from_tensors(data).repeat(5)
+    map_fns = [array_ops.unstack, lambda x: array_ops.unstack(x, axis=-1)]
+    for map_fn in map_fns:
+      self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
+
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         _unary_bitwise_test_combinations(),
+                         combinations.combine(num_parallel_calls=[None, 12])))
+  def testUnaryBitwiseOperations(self, map_fn, num_parallel_calls):
+    x = np.random.randint(0, 10, (7, 3, 5))
+    dataset_factory = lambda: dataset_ops.Dataset.from_tensor_slices(x)
+    self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
+
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         _unary_logical_test_combinations(),
+                         combinations.combine(num_parallel_calls=[None, 12])))
+  def testUnaryLogicalOperations(self, map_fn, num_parallel_calls):
+    x = np.random.rand(3, 5)
+    dataset_factory = lambda: dataset_ops.Dataset.from_tensor_slices(x > 0)
+    self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
+
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         _unary_complex_test_combinations(),
+                         combinations.combine(num_parallel_calls=[None, 12])))
+  def testUnaryComplexOperations(self, map_fn, num_parallel_calls):
+    x = math_ops.complex(np.random.rand(3, 5), np.random.rand(3, 5))
+    dataset_factory = lambda: dataset_ops.Dataset.from_tensor_slices(x)
+    self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
+
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         _unary_real_test_combinations(),
+                         combinations.combine(num_parallel_calls=[None, 12])))
+  def testUnaryRealOperations(self, map_fn, num_parallel_calls):
+    x = np.random.rand(3, 5)
+    dataset_factory = lambda: dataset_ops.Dataset.from_tensor_slices(x)
+    self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
+
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         _binary_bitwise_test_combinations(),
+                         combinations.combine(num_parallel_calls=[None, 12])))
+  def testBinaryBitwiseOperations(self, map_fn, num_parallel_calls):
+    x = np.random.randint(0, 10, (7, 3, 5))
+    y = np.random.randint(0, 10, (3, 5))
+    dataset_factory = lambda: dataset_ops.Dataset.from_tensors((x, y))
+    self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
+
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         _binary_logical_test_combinations(),
+                         combinations.combine(num_parallel_calls=[None, 12])))
+  def testBinaryLogicalOperations(self, map_fn, num_parallel_calls):
+    x = np.random.rand(7, 3, 5)
+    y = np.random.rand(3, 5)
+    dataset_factory = lambda: dataset_ops.Dataset.from_tensors((x > 0, y > 0))
+    self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
+
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         _binary_real_test_combinations(),
+                         combinations.combine(num_parallel_calls=[None, 12])))
+  def testBinaryRealOperations(self, map_fn, num_parallel_calls):
+    x = np.random.rand(7, 3, 5)
+    y = np.random.rand(3, 5)
+    dataset_factory = lambda: dataset_ops.Dataset.from_tensors((x, y))
+    self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
+
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         combinations.combine(num_parallel_calls=[None, 12])))
+  def testDecodeCsv(self, num_parallel_calls):
+
+    def dataset_factory():
+      return dataset_ops.Dataset.from_tensor_slices(["1.0:2:a",
+                                                     "2.4:5:c"]).repeat(5)
+
+    def decode_csv_fn(x):
+      return parsing_ops.decode_csv(
+          x,
+          record_defaults=[
+              constant_op.constant([], dtypes.float32),
+              constant_op.constant([], dtypes.int32),
+              constant_op.constant([], dtypes.string)
+          ],
+          field_delim=":")
+
+    self._testOptimization(decode_csv_fn, dataset_factory, num_parallel_calls)
+
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         combinations.combine(num_parallel_calls=[None, 12])))
+  def testParseSingleExample(self, num_parallel_calls):
+
+    def dataset_factory():
+
+      def _int64_feature(*values):
+        return feature_pb2.Feature(
+            int64_list=feature_pb2.Int64List(value=values))
+
+      def _bytes_feature(*values):
+        return feature_pb2.Feature(
+            bytes_list=feature_pb2.BytesList(
+                value=[v.encode("utf-8") for v in values]))
+
+      # pylint:disable=g-complex-comprehension
+      return dataset_ops.Dataset.from_tensor_slices(
+          constant_op.constant([
+              example_pb2.Example(
+                  features=feature_pb2.Features(
+                      feature={
+                          "dense_int": _int64_feature(i),
+                          "dense_str": _bytes_feature(str(i)),
+                      })).SerializeToString() for i in range(10)
+          ]))
+
+    def parse_fn(x):
+      features = {
+          "dense_int": parsing_ops.FixedLenFeature((), dtypes.int64, 0),
+          "dense_str": parsing_ops.FixedLenFeature((), dtypes.string, ""),
+      }
+      return parsing_ops.parse_single_example(x, features)
+
+    def dense_only_parse_fn(x):
+      return [
+          y for y in parse_fn(x)
+          if not isinstance(y, sparse_tensor.SparseTensor)
+      ]
+
+    map_fns = [parse_fn, dense_only_parse_fn]
+
+    for map_fn in map_fns:
+      self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
+
+  @combinations.generate(test_base.default_test_combinations())
   def testOptimizationBadMapFn(self):
     # Test map functions that give an error
     def map_fn(x):
@@ -391,6 +468,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
       nxt = dataset_ops.make_one_shot_iterator(optimized).get_next()
       self.evaluate(nxt)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testOptimizationWithCapturedInputs(self):
     # Tests that vectorization works with captured inputs.
     y = constant_op.constant(1, shape=(2,))
@@ -405,6 +483,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
         base_dataset, map_fn, expect_optimized=True)
     self.assertDatasetsEqual(optimized, unoptimized)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testOptimizationWithMapAndBatchFusion(self):
     # Tests that vectorization works on fused map and batch.
     def map_fn(x):
@@ -425,12 +504,11 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
     optimized = self._enable_map_vectorization(optimized)
     self.assertDatasetsEqual(optimized, unoptimized)
 
-  @parameterized.named_parameters(
-      ("1", True, True),
-      ("2", True, False),
-      ("3", False, True),
-      ("4", False, False),
-  )
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(
+              fuse_first=[True, False], fuse_second=[True, False])))
   def testOptimizationWithChainedMapAndBatch(self, fuse_first, fuse_second):
     # Tests that vectorization works on chained map and batch functions.
     def map_fn(x):
@@ -474,6 +552,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
     optimized = self._enable_map_vectorization(optimized)
     self.assertDatasetsEqual(optimized, unoptimized)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testOptimizationIgnoreStateful(self):
 
     def map_fn(x):
@@ -488,6 +567,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
       get_next = self.getNext(dataset)
       self.evaluate(get_next())
 
+  @combinations.generate(test_base.default_test_combinations())
   def testOptimizationIgnoreRagged(self):
     # Make sure we ignore inputs that might not be uniformly sized
     def map_fn(x):
@@ -499,6 +579,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
         base_dataset, map_fn, expect_optimized=False)
     self.assertDatasetsEqual(unoptimized, optimized)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testOptimizationIgnoreRaggedMap(self):
     # Don't optimize when the output of the map fn shapes are unknown.
     def map_fn(x):
@@ -512,6 +593,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
       get_next = self.getNext(dataset)
       self.evaluate(get_next())
 
+  @combinations.generate(test_base.default_test_combinations())
   def testOptimizationWithUnknownBatchShape(self):
     tensor = sparse_tensor.SparseTensor(
         indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
@@ -526,6 +608,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
     optimized = self._enable_map_vectorization(unoptimized)
     self.assertDatasetsEqual(unoptimized, optimized)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testOptimizationWithSparseTensor(self):
     base_dataset = dataset_ops.Dataset.from_tensors(0)
 
@@ -542,6 +625,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
     optimized = self._enable_map_vectorization(unoptimized)
     self.assertDatasetsEqual(unoptimized, optimized)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testOptimizationWithPrefetch(self):
     dataset = dataset_ops.Dataset.range(10)
     dataset = dataset.map(lambda x: x)
@@ -550,6 +634,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
     dataset = self._enable_map_vectorization(dataset)
     self.assertDatasetProduces(dataset, [list(range(10))])
 
+  @combinations.generate(test_base.default_test_combinations())
   def testOptimizationWithoutChooseFastest(self):
     dataset = dataset_ops.Dataset.range(10)
     dataset = dataset.map(lambda x: x**2)
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/noop_elimination_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/noop_elimination_test.py
index a401a5c8baf..84ef45d9593 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/noop_elimination_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/noop_elimination_test.py
@@ -17,19 +17,21 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
+
 from tensorflow.python.data.experimental.ops import testing
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import combinations
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import test_util
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 
 
-@test_util.run_all_in_graph_and_eager_modes
-class NoopEliminationTest(test_base.DatasetTestBase):
+class NoopEliminationTest(test_base.DatasetTestBase, parameterized.TestCase):
 
+  @combinations.generate(test_base.default_test_combinations())
   def testNoopElimination(self):
     a = constant_op.constant(1, dtype=dtypes.int64)
     b = constant_op.constant(2, dtype=dtypes.int64)
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/shuffle_and_repeat_fusion_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/shuffle_and_repeat_fusion_test.py
index 4da7fa27d58..ad1a98134b8 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/shuffle_and_repeat_fusion_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/shuffle_and_repeat_fusion_test.py
@@ -17,19 +17,22 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
+
 from tensorflow.python import tf2
 from tensorflow.python.data.experimental.ops import testing
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import context
+from tensorflow.python.framework import combinations
 from tensorflow.python.framework import errors
-from tensorflow.python.framework import test_util
 from tensorflow.python.platform import test
 
 
-@test_util.run_all_in_graph_and_eager_modes
-class ShuffleAndRepeatFusionTest(test_base.DatasetTestBase):
+class ShuffleAndRepeatFusionTest(test_base.DatasetTestBase,
+                                 parameterized.TestCase):
 
+  @combinations.generate(test_base.default_test_combinations())
   def testShuffleAndRepeatFusion(self):
     if tf2.enabled() and context.executing_eagerly():
       expected = "Shuffle"

From ed17bc62837605a8802756c751d895d54f998c73 Mon Sep 17 00:00:00 2001
From: Gunhan Gulsoy 
Date: Sat, 30 Nov 2019 13:49:26 -0800
Subject: [PATCH 119/279] Remove dependence on core/lib/core/threadpool under
 tf/core/platform.

PiperOrigin-RevId: 283163508
Change-Id: Id27296622591efa97865f4328919b2a3705765ca
---
 tensorflow/core/platform/port_test.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/core/platform/port_test.cc b/tensorflow/core/platform/port_test.cc
index 94a9e4d4589..4f59ed6f1c5 100644
--- a/tensorflow/core/platform/port_test.cc
+++ b/tensorflow/core/platform/port_test.cc
@@ -15,12 +15,12 @@ limitations under the License.
 
 #include 
 
-#include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/platform/cpu_info.h"
 #include "tensorflow/core/platform/env_time.h"
 #include "tensorflow/core/platform/mem.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/threadpool.h"
 
 namespace tensorflow {
 namespace port {

From 699e26023a928f05fdbc76e619f4e565c312d483 Mon Sep 17 00:00:00 2001
From: Gunhan Gulsoy 
Date: Sat, 30 Nov 2019 13:53:45 -0800
Subject: [PATCH 120/279] Move core/lib/hash/hash library into core/platform

PiperOrigin-RevId: 283163718
Change-Id: Id26152d57361c0f1c513949c4819eb05c4d126d4
---
 tensorflow/core/BUILD                         |   1 +
 tensorflow/core/lib/hash/BUILD                |   7 +-
 tensorflow/core/lib/hash/hash.h               |  92 +-------------
 tensorflow/core/platform/BUILD                |  14 +++
 .../core/{lib/hash => platform}/hash.cc       |   4 +-
 tensorflow/core/platform/hash.h               | 113 ++++++++++++++++++
 6 files changed, 132 insertions(+), 99 deletions(-)
 rename tensorflow/core/{lib/hash => platform}/hash.cc (96%)
 create mode 100644 tensorflow/core/platform/hash.h

diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 3839442c167..f18242e51e7 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -2270,6 +2270,7 @@ cc_library(
         "//tensorflow/core/platform:errors",
         "//tensorflow/core/platform:file_statistics",
         "//tensorflow/core/platform:fingerprint",
+        "//tensorflow/core/platform:hash",
         "//tensorflow/core/platform:load_library",
         "//tensorflow/core/platform:logger",
         "//tensorflow/core/platform:mutex",
diff --git a/tensorflow/core/lib/hash/BUILD b/tensorflow/core/lib/hash/BUILD
index a44e7836cab..de2eebc785f 100644
--- a/tensorflow/core/lib/hash/BUILD
+++ b/tensorflow/core/lib/hash/BUILD
@@ -41,13 +41,9 @@ cc_library(
 
 cc_library(
     name = "hash",
-    srcs = ["hash.cc"],
     hdrs = ["hash.h"],
     deps = [
-        "//tensorflow/core/lib/core:raw_coding",
-        "//tensorflow/core/lib/core:stringpiece",
-        "//tensorflow/core/platform:macros",
-        "//tensorflow/core/platform:types",
+        "//tensorflow/core/platform:hash",
     ],
 )
 
@@ -65,7 +61,6 @@ filegroup(
     srcs = [
         "crc32c.cc",
         "crc32c_accelerate.cc",
-        "hash.cc",
     ],
     visibility = ["//tensorflow/core:__pkg__"],
 )
diff --git a/tensorflow/core/lib/hash/hash.h b/tensorflow/core/lib/hash/hash.h
index 675bab71919..fa2cc295b15 100644
--- a/tensorflow/core/lib/hash/hash.h
+++ b/tensorflow/core/lib/hash/hash.h
@@ -18,96 +18,6 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_LIB_HASH_HASH_H_
 #define TENSORFLOW_CORE_LIB_HASH_HASH_H_
 
-#include 
-#include 
-
-#include 
-#include 
-
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-
-extern uint32 Hash32(const char* data, size_t n, uint32 seed);
-extern uint64 Hash64(const char* data, size_t n, uint64 seed);
-
-inline uint64 Hash64(const char* data, size_t n) {
-  return Hash64(data, n, 0xDECAFCAFFE);
-}
-
-inline uint64 Hash64(const string& str) {
-  return Hash64(str.data(), str.size());
-}
-
-inline uint64 Hash64Combine(uint64 a, uint64 b) {
-  return a ^ (b + 0x9e3779b97f4a7800ULL + (a << 10) + (a >> 4));
-}
-
-// Combine two hashes in an order-independent way. This operation should be
-// associative and compute the same hash for a collection of elements
-// independent of traversal order. Note that it is better to combine hashes
-// symmetrically with addition rather than XOR, since (x^x) == 0 but (x+x) != 0.
-inline uint64 Hash64CombineUnordered(uint64 a, uint64 b) { return a + b; }
-
-// Hash functor suitable for use with power-of-two sized hashtables.  Use
-// instead of std::hash.
-//
-// In particular, tensorflow::hash is not the identity function for pointers.
-// This is important for power-of-two sized hashtables like FlatMap and FlatSet,
-// because otherwise they waste the majority of their hash buckets.
-//
-// The second type argument is only used for SFNIAE below.
-template 
-struct hash {
-  size_t operator()(const T& t) const { return std::hash()(t); }
-};
-
-template 
-struct hash::value>::type> {
-  size_t operator()(T value) const {
-    // This works around a defect in the std::hash C++ spec that isn't fixed in
-    // (at least) gcc 4.8.4:
-    // http://www.open-std.org/jtc1/sc22/wg21/docs/lwg-defects.html#2148
-    //
-    // We should be able to remove this and use the default
-    // tensorflow::hash() once we stop building with GCC versions old
-    // enough to not have this defect fixed.
-    return std::hash()(static_cast(value));
-  }
-};
-
-template 
-struct hash {
-  size_t operator()(const T* t) const {
-    // Hash pointers as integers, but bring more entropy to the lower bits.
-    size_t k = static_cast(reinterpret_cast(t));
-    return k + (k >> 6);
-  }
-};
-
-template <>
-struct hash {
-  size_t operator()(const string& s) const {
-    return static_cast(Hash64(s));
-  }
-};
-
-template <>
-struct hash {
-  size_t operator()(StringPiece sp) const {
-    return static_cast(Hash64(sp.data(), sp.size()));
-  }
-};
-using StringPieceHasher = ::tensorflow::hash;
-
-template 
-struct hash> {
-  size_t operator()(const std::pair& p) const {
-    return Hash64Combine(hash()(p.first), hash()(p.second));
-  }
-};
-
-}  // namespace tensorflow
+#include "tensorflow/core/platform/hash.h"
 
 #endif  // TENSORFLOW_CORE_LIB_HASH_HASH_H_
diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD
index ecc44e39c11..1c363209a48 100644
--- a/tensorflow/core/platform/BUILD
+++ b/tensorflow/core/platform/BUILD
@@ -215,6 +215,18 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "hash",
+    srcs = ["hash.cc"],
+    hdrs = ["hash.h"],
+    deps = [
+        ":macros",
+        ":raw_coding",
+        ":stringpiece",
+        ":types",
+    ],
+)
+
 cc_library(
     name = "human_readable_json",
     textual_hdrs = ["human_readable_json.h"],
@@ -735,6 +747,7 @@ filegroup(
             "**/windows_file_system.cc",
             "abi.cc",
             "cpu_info.cc",
+            "hash.cc",
             "numbers.cc",
             "path.cc",
             "platform_strings.cc",
@@ -846,6 +859,7 @@ filegroup(
             "error.cc",
             "file_system.cc",
             "file_system_helper.cc",
+            "hash.cc",
             "logger.cc",
             "numbers.cc",
             "path.cc",
diff --git a/tensorflow/core/lib/hash/hash.cc b/tensorflow/core/platform/hash.cc
similarity index 96%
rename from tensorflow/core/lib/hash/hash.cc
rename to tensorflow/core/platform/hash.cc
index dc9d300d00e..74a18f8f05e 100644
--- a/tensorflow/core/lib/hash/hash.cc
+++ b/tensorflow/core/platform/hash.cc
@@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/hash.h"
 
-#include "tensorflow/core/lib/core/raw_coding.h"
 #include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/raw_coding.h"
 #include "tensorflow/core/platform/types.h"
 
 #include 
diff --git a/tensorflow/core/platform/hash.h b/tensorflow/core/platform/hash.h
new file mode 100644
index 00000000000..3a9de99f2bc
--- /dev/null
+++ b/tensorflow/core/platform/hash.h
@@ -0,0 +1,113 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Simple hash functions used for internal data structures
+
+#ifndef TENSORFLOW_CORE_PLATFORM_HASH_H_
+#define TENSORFLOW_CORE_PLATFORM_HASH_H_
+
+#include 
+#include 
+
+#include 
+#include 
+
+#include "tensorflow/core/platform/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+extern uint32 Hash32(const char* data, size_t n, uint32 seed);
+extern uint64 Hash64(const char* data, size_t n, uint64 seed);
+
+inline uint64 Hash64(const char* data, size_t n) {
+  return Hash64(data, n, 0xDECAFCAFFE);
+}
+
+inline uint64 Hash64(const string& str) {
+  return Hash64(str.data(), str.size());
+}
+
+inline uint64 Hash64Combine(uint64 a, uint64 b) {
+  return a ^ (b + 0x9e3779b97f4a7800ULL + (a << 10) + (a >> 4));
+}
+
+// Combine two hashes in an order-independent way. This operation should be
+// associative and compute the same hash for a collection of elements
+// independent of traversal order. Note that it is better to combine hashes
+// symmetrically with addition rather than XOR, since (x^x) == 0 but (x+x) != 0.
+inline uint64 Hash64CombineUnordered(uint64 a, uint64 b) { return a + b; }
+
+// Hash functor suitable for use with power-of-two sized hashtables.  Use
+// instead of std::hash.
+//
+// In particular, tensorflow::hash is not the identity function for pointers.
+// This is important for power-of-two sized hashtables like FlatMap and FlatSet,
+// because otherwise they waste the majority of their hash buckets.
+//
+// The second type argument is only used for SFNIAE below.
+template 
+struct hash {
+  size_t operator()(const T& t) const { return std::hash()(t); }
+};
+
+template 
+struct hash::value>::type> {
+  size_t operator()(T value) const {
+    // This works around a defect in the std::hash C++ spec that isn't fixed in
+    // (at least) gcc 4.8.4:
+    // http://www.open-std.org/jtc1/sc22/wg21/docs/lwg-defects.html#2148
+    //
+    // We should be able to remove this and use the default
+    // tensorflow::hash() once we stop building with GCC versions old
+    // enough to not have this defect fixed.
+    return std::hash()(static_cast(value));
+  }
+};
+
+template 
+struct hash {
+  size_t operator()(const T* t) const {
+    // Hash pointers as integers, but bring more entropy to the lower bits.
+    size_t k = static_cast(reinterpret_cast(t));
+    return k + (k >> 6);
+  }
+};
+
+template <>
+struct hash {
+  size_t operator()(const string& s) const {
+    return static_cast(Hash64(s));
+  }
+};
+
+template <>
+struct hash {
+  size_t operator()(StringPiece sp) const {
+    return static_cast(Hash64(sp.data(), sp.size()));
+  }
+};
+using StringPieceHasher = ::tensorflow::hash;
+
+template 
+struct hash> {
+  size_t operator()(const std::pair& p) const {
+    return Hash64Combine(hash()(p.first), hash()(p.second));
+  }
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PLATFORM_HASH_H_

From 9a691fd4967045f19bd5faeb5b694b50058b0a14 Mon Sep 17 00:00:00 2001
From: Gunhan Gulsoy 
Date: Sat, 30 Nov 2019 14:39:01 -0800
Subject: [PATCH 121/279] Move tensorflow/core/lib/core/coding to
 tensorflow/core/platform

PiperOrigin-RevId: 283166018
Change-Id: I7123412386eeb448ca01ea27b859292128a35cad
---
 tensorflow/core/BUILD                         |  1 +
 tensorflow/core/lib/core/BUILD                |  7 +-
 tensorflow/core/lib/core/coding.h             | 45 +-----------
 tensorflow/core/platform/BUILD                | 14 ++++
 .../core/{lib/core => platform}/coding.cc     |  2 +-
 tensorflow/core/platform/coding.h             | 69 +++++++++++++++++++
 6 files changed, 87 insertions(+), 51 deletions(-)
 rename tensorflow/core/{lib/core => platform}/coding.cc (99%)
 create mode 100644 tensorflow/core/platform/coding.h

diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index f18242e51e7..492a6fef318 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -2260,6 +2260,7 @@ cc_library(
         "//tensorflow/core/lib/strings:strcat",
         "//tensorflow/core/lib/strings:stringprintf",
         "//tensorflow/core/platform:abi",
+        "//tensorflow/core/platform:coding",
         "//tensorflow/core/platform:context",
         "//tensorflow/core/platform:cord",
         "//tensorflow/core/platform:cpu_feature_guard",
diff --git a/tensorflow/core/lib/core/BUILD b/tensorflow/core/lib/core/BUILD
index baf81113029..baca47df789 100644
--- a/tensorflow/core/lib/core/BUILD
+++ b/tensorflow/core/lib/core/BUILD
@@ -53,13 +53,9 @@ cc_library(
 
 cc_library(
     name = "coding",
-    srcs = ["coding.cc"],
     hdrs = ["coding.h"],
     deps = [
-        "//tensorflow/core/lib/core:raw_coding",
-        "//tensorflow/core/lib/core:stringpiece",
-        "//tensorflow/core/platform:byte_order",
-        "//tensorflow/core/platform:types",
+        "//tensorflow/core/platform:coding",
     ],
 )
 
@@ -172,7 +168,6 @@ filegroup(
     srcs = [
         "arena.cc",
         "bitmap.cc",
-        "coding.cc",
     ],
     visibility = ["//tensorflow/core:__pkg__"],
 )
diff --git a/tensorflow/core/lib/core/coding.h b/tensorflow/core/lib/core/coding.h
index bfab80dd007..a1121c888dd 100644
--- a/tensorflow/core/lib/core/coding.h
+++ b/tensorflow/core/lib/core/coding.h
@@ -21,49 +21,6 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_LIB_CORE_CODING_H_
 #define TENSORFLOW_CORE_LIB_CORE_CODING_H_
 
-#include "tensorflow/core/lib/core/raw_coding.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-namespace core {
-
-// Maximum number of bytes occupied by a varint32.
-static const int kMaxVarint32Bytes = 5;
-
-// Maximum number of bytes occupied by a varint64.
-static const int kMaxVarint64Bytes = 10;
-
-// Lower-level versions of Put... that write directly into a character buffer
-// REQUIRES: dst has enough space for the value being written
-extern void EncodeFixed16(char* dst, uint16 value);
-extern void EncodeFixed32(char* dst, uint32 value);
-extern void EncodeFixed64(char* dst, uint64 value);
-extern void PutFixed16(string* dst, uint16 value);
-extern void PutFixed32(string* dst, uint32 value);
-extern void PutFixed64(string* dst, uint64 value);
-
-extern void PutVarint32(string* dst, uint32 value);
-extern void PutVarint64(string* dst, uint64 value);
-
-extern bool GetVarint32(StringPiece* input, uint32* value);
-extern bool GetVarint64(StringPiece* input, uint64* value);
-
-extern const char* GetVarint32Ptr(const char* p, const char* limit, uint32* v);
-extern const char* GetVarint64Ptr(const char* p, const char* limit, uint64* v);
-
-// Internal routine for use by fallback path of GetVarint32Ptr
-extern const char* GetVarint32PtrFallback(const char* p, const char* limit,
-                                          uint32* value);
-extern const char* GetVarint32Ptr(const char* p, const char* limit,
-                                  uint32* value);
-extern char* EncodeVarint32(char* dst, uint32 v);
-extern char* EncodeVarint64(char* dst, uint64 v);
-
-// Returns the length of the varint32 or varint64 encoding of "v"
-extern int VarintLength(uint64_t v);
-
-}  // namespace core
-}  // namespace tensorflow
+#include "tensorflow/core/platform/coding.h"
 
 #endif  // TENSORFLOW_CORE_LIB_CORE_CODING_H_
diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD
index 1c363209a48..e800684fe53 100644
--- a/tensorflow/core/platform/BUILD
+++ b/tensorflow/core/platform/BUILD
@@ -88,6 +88,18 @@ cc_library(
     hdrs = ["byte_order.h"],
 )
 
+cc_library(
+    name = "coding",
+    srcs = ["coding.cc"],
+    hdrs = ["coding.h"],
+    deps = [
+        ":byte_order",
+        ":raw_coding",
+        ":stringpiece",
+        ":types",
+    ],
+)
+
 cc_library(
     name = "context",
     textual_hdrs = ["context.h"],
@@ -746,6 +758,7 @@ filegroup(
             "**/unbounded_work_queue.cc",
             "**/windows_file_system.cc",
             "abi.cc",
+            "coding.cc",
             "cpu_info.cc",
             "hash.cc",
             "numbers.cc",
@@ -852,6 +865,7 @@ filegroup(
             "**/human_readable_json.cc",
             "**/rocm_rocdl_path.cc",
             "abi.cc",
+            "coding.cc",
             "cpu_info.cc",
             "cpu_feature_guard.cc",
             "denormal.cc",
diff --git a/tensorflow/core/lib/core/coding.cc b/tensorflow/core/platform/coding.cc
similarity index 99%
rename from tensorflow/core/lib/core/coding.cc
rename to tensorflow/core/platform/coding.cc
index 4c33dfa211e..ef0df8fa42a 100644
--- a/tensorflow/core/lib/core/coding.cc
+++ b/tensorflow/core/platform/coding.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/platform/coding.h"
 
 #include "tensorflow/core/platform/byte_order.h"
 
diff --git a/tensorflow/core/platform/coding.h b/tensorflow/core/platform/coding.h
new file mode 100644
index 00000000000..cd66e54dfdb
--- /dev/null
+++ b/tensorflow/core/platform/coding.h
@@ -0,0 +1,69 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Endian-neutral encoding:
+// * Fixed-length numbers are encoded with least-significant byte first
+// * In addition we support variable length "varint" encoding
+// * Strings are encoded prefixed by their length in varint format
+
+#ifndef TENSORFLOW_CORE_PLATFORM_CODING_H_
+#define TENSORFLOW_CORE_PLATFORM_CODING_H_
+
+#include "tensorflow/core/platform/raw_coding.h"
+#include "tensorflow/core/platform/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace core {
+
+// Maximum number of bytes occupied by a varint32.
+static const int kMaxVarint32Bytes = 5;
+
+// Maximum number of bytes occupied by a varint64.
+static const int kMaxVarint64Bytes = 10;
+
+// Lower-level versions of Put... that write directly into a character buffer
+// REQUIRES: dst has enough space for the value being written
+extern void EncodeFixed16(char* dst, uint16 value);
+extern void EncodeFixed32(char* dst, uint32 value);
+extern void EncodeFixed64(char* dst, uint64 value);
+extern void PutFixed16(string* dst, uint16 value);
+extern void PutFixed32(string* dst, uint32 value);
+extern void PutFixed64(string* dst, uint64 value);
+
+extern void PutVarint32(string* dst, uint32 value);
+extern void PutVarint64(string* dst, uint64 value);
+
+extern bool GetVarint32(StringPiece* input, uint32* value);
+extern bool GetVarint64(StringPiece* input, uint64* value);
+
+extern const char* GetVarint32Ptr(const char* p, const char* limit, uint32* v);
+extern const char* GetVarint64Ptr(const char* p, const char* limit, uint64* v);
+
+// Internal routine for use by fallback path of GetVarint32Ptr
+extern const char* GetVarint32PtrFallback(const char* p, const char* limit,
+                                          uint32* value);
+extern const char* GetVarint32Ptr(const char* p, const char* limit,
+                                  uint32* value);
+extern char* EncodeVarint32(char* dst, uint32 v);
+extern char* EncodeVarint64(char* dst, uint64 v);
+
+// Returns the length of the varint32 or varint64 encoding of "v"
+extern int VarintLength(uint64_t v);
+
+}  // namespace core
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PLATFORM_CODING_H_

From c4dae6499d32806823caf86275393f531d7c1ce0 Mon Sep 17 00:00:00 2001
From: Gunhan Gulsoy 
Date: Sat, 30 Nov 2019 15:18:13 -0800
Subject: [PATCH 122/279] Remove dependence on core/lib/core/refcount under
 tensorflow/core/platform.

Library has been moved to platform/refcount.h/cc

PiperOrigin-RevId: 283167865
Change-Id: I84489ff9603cf010605c6f3b6be6c4dcabf10183
---
 tensorflow/core/platform/BUILD           | 2 +-
 tensorflow/core/platform/tensor_coding.h | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD
index e800684fe53..8649bb996c1 100644
--- a/tensorflow/core/platform/BUILD
+++ b/tensorflow/core/platform/BUILD
@@ -601,11 +601,11 @@ cc_library(
     deps = [
         ":platform",
         ":protobuf",
+        ":refcount",
         ":stringpiece",
         ":strcat",
         ":types",
         "//tensorflow/core/lib/core:coding",
-        "//tensorflow/core/lib/core:refcount",
     ] + tf_additional_tensor_coding_deps(),
 )
 
diff --git a/tensorflow/core/platform/tensor_coding.h b/tensorflow/core/platform/tensor_coding.h
index 63e47a880a9..fcfa5469e18 100644
--- a/tensorflow/core/platform/tensor_coding.h
+++ b/tensorflow/core/platform/tensor_coding.h
@@ -19,9 +19,9 @@ limitations under the License.
 
 #include 
 
-#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/platform/platform.h"
 #include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/refcount.h"
 #include "tensorflow/core/platform/stringpiece.h"
 #include "tensorflow/core/platform/types.h"
 

From 0654c7829baabdc6ac390c63ca922eb2c233a4c0 Mon Sep 17 00:00:00 2001
From: Gunhan Gulsoy 
Date: Sat, 30 Nov 2019 15:49:44 -0800
Subject: [PATCH 123/279] Remove dependencies on core/lib/hash under
 tensorflow/core/platform.

The library has moved to core/platform.

PiperOrigin-RevId: 283169462
Change-Id: I409fabd04bd5254ff94d3129ebcd399a5cbc15ab
---
 tensorflow/core/platform/default/build_refactor.bzl | 1 +
 tensorflow/core/platform/tracing.cc                 | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/tensorflow/core/platform/default/build_refactor.bzl b/tensorflow/core/platform/default/build_refactor.bzl
index 5dbada0a08e..303a22cdde2 100644
--- a/tensorflow/core/platform/default/build_refactor.bzl
+++ b/tensorflow/core/platform/default/build_refactor.bzl
@@ -342,6 +342,7 @@ TF_DEFAULT_PLATFORM_LIBRARIES = {
         "deps": [
             "//tensorflow/core/lib/hash",
             "//tensorflow/core/platform",
+            "//tensorflow/core/platform:hash",
             "//tensorflow/core/platform:logging",
             "//tensorflow/core/platform:macros",
             "//tensorflow/core/platform:strcat",
diff --git a/tensorflow/core/platform/tracing.cc b/tensorflow/core/platform/tracing.cc
index 30aa664ae01..a7745903d4b 100644
--- a/tensorflow/core/platform/tracing.cc
+++ b/tensorflow/core/platform/tracing.cc
@@ -18,7 +18,7 @@ limitations under the License.
 #include 
 #include 
 
-#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/hash.h"
 
 namespace tensorflow {
 namespace tracing {

From bd754067dac90182d883f621b775d76ec7c6b87d Mon Sep 17 00:00:00 2001
From: Gunhan Gulsoy 
Date: Sat, 30 Nov 2019 15:57:26 -0800
Subject: [PATCH 124/279] Remove dependence on tensorflow/core/lib/core/coding
 under tensorflow/core/platform

That library has been moved to tensorflow/core/platform

PiperOrigin-RevId: 283169817
Change-Id: Id9990117edbea7795c8de94ca9e5b8ddb3ce4101
---
 tensorflow/core/platform/BUILD            | 2 +-
 tensorflow/core/platform/tensor_coding.cc | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD
index 8649bb996c1..d4ee72cc11d 100644
--- a/tensorflow/core/platform/BUILD
+++ b/tensorflow/core/platform/BUILD
@@ -599,13 +599,13 @@ cc_library(
     srcs = ["tensor_coding.cc"],
     hdrs = ["tensor_coding.h"],
     deps = [
+        ":coding",
         ":platform",
         ":protobuf",
         ":refcount",
         ":stringpiece",
         ":strcat",
         ":types",
-        "//tensorflow/core/lib/core:coding",
     ] + tf_additional_tensor_coding_deps(),
 )
 
diff --git a/tensorflow/core/platform/tensor_coding.cc b/tensorflow/core/platform/tensor_coding.cc
index f115da2b4d6..c12810a42d6 100644
--- a/tensorflow/core/platform/tensor_coding.cc
+++ b/tensorflow/core/platform/tensor_coding.cc
@@ -17,7 +17,7 @@ limitations under the License.
 
 #include 
 
-#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/platform/coding.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/strcat.h"
 #include "tensorflow/core/platform/stringpiece.h"

From dfc484e8997fe8b23f3f17a01c8fa0cfe230514b Mon Sep 17 00:00:00 2001
From: hsahovic 
Date: Sun, 1 Dec 2019 01:37:18 -0500
Subject: [PATCH 125/279] Fix misformatted markdown in Model.fit docstring

The current documentation of `tf.keras.models.Model.fit` contains misformatted markdown in the description of the `validation_data` attribute: it is missing a line break after the list enumerating the possible types of accepted arguments, which hurts legibility upon rendering.

See [keras documentation](https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit).

This commit fixes that problem.
---
 tensorflow/python/keras/engine/training.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 40fdb0c79a3..36570e36cc8 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -676,6 +676,7 @@ class Model(network.Network, version_utils.VersionSelector):
               - tuple `(x_val, y_val)` of Numpy arrays or tensors
               - tuple `(x_val, y_val, val_sample_weights)` of Numpy arrays
               - dataset
+              
             For the first two cases, `batch_size` must be provided.
             For the last case, `validation_steps` could be provided.
         shuffle: Boolean (whether to shuffle the training data

From 908d137c358600d3c0c40435219734f8339930eb Mon Sep 17 00:00:00 2001
From: Gunhan Gulsoy 
Date: Sat, 30 Nov 2019 16:29:50 -0800
Subject: [PATCH 126/279] Replace the core/lib/io/path usages within
 core/platform.

This library has moved into core/platform.

PiperOrigin-RevId: 283171725
Change-Id: I3e8468254189dcafd7937fe16b14e428ad27cc95
---
 tensorflow/core/BUILD                           |  1 +
 tensorflow/core/platform/cloud/BUILD            |  6 ++++++
 .../platform/cloud/curl_http_request_test.cc    |  4 +++-
 .../core/platform/cloud/gcs_file_system.cc      |  2 +-
 .../core/platform/cloud/google_auth_provider.cc |  2 +-
 .../platform/cloud/google_auth_provider_test.cc |  4 +++-
 .../core/platform/cloud/oauth_client_test.cc    |  2 +-
 .../core/platform/default/build_refactor.bzl    | 17 ++++++-----------
 .../core/platform/default/rocm_rocdl_path.cc    |  2 +-
 tensorflow/core/platform/env_test.cc            |  2 +-
 tensorflow/core/platform/file_system_test.cc    |  2 +-
 tensorflow/core/platform/hadoop/BUILD           |  2 ++
 .../core/platform/hadoop/hadoop_file_system.cc  |  2 +-
 .../platform/hadoop/hadoop_file_system_test.cc  |  2 +-
 .../core/platform/platform_strings_test.cc      |  2 +-
 .../core/platform/rocm_rocdl_path_test.cc       |  2 +-
 tensorflow/core/platform/s3/BUILD               |  2 ++
 tensorflow/core/platform/s3/s3_file_system.cc   |  2 +-
 .../core/platform/s3/s3_file_system_test.cc     |  2 +-
 .../core/platform/windows/windows_file_system.h |  2 +-
 20 files changed, 36 insertions(+), 26 deletions(-)

diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 492a6fef318..6ab27d6644d 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -396,6 +396,7 @@ filegroup(
         "//tensorflow/core/platform:env.h",
         "//tensorflow/core/platform:file_statistics.h",
         "//tensorflow/core/platform:file_system.h",
+        "//tensorflow/core/platform:path.h",
     ],
     visibility = ["//visibility:private"],
 )
diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD
index a4019273fc9..1ad3d06f5bb 100644
--- a/tensorflow/core/platform/cloud/BUILD
+++ b/tensorflow/core/platform/cloud/BUILD
@@ -100,6 +100,7 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core/platform:numbers",
+        "//tensorflow/core/platform:path",
         "//tensorflow/core/platform:str_util",
         "//tensorflow/core/platform:stringprintf",
         "@jsoncpp_git//:jsoncpp",
@@ -135,6 +136,7 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core/platform:numbers",
+        "//tensorflow/core/platform:path",
         "//tensorflow/core/platform:str_util",
         "//tensorflow/core/platform:stringprintf",
         "@jsoncpp_git//:jsoncpp",
@@ -203,6 +205,7 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core/platform:errors",
+        "//tensorflow/core/platform:path",
         "//tensorflow/core/platform:status",
         "@com_google_absl//absl/strings",
         "@jsoncpp_git//:jsoncpp",
@@ -401,6 +404,7 @@ tf_cc_test(
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "//tensorflow/core/platform:path",
     ],
 )
 
@@ -419,6 +423,7 @@ tf_cc_test(
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "//tensorflow/core/platform:path",
         "//tensorflow/core/platform:scanner",
         "@boringssl//:crypto",
     ],
@@ -440,6 +445,7 @@ tf_cc_test(
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "//tensorflow/core/platform:path",
     ],
 )
 
diff --git a/tensorflow/core/platform/cloud/curl_http_request_test.cc b/tensorflow/core/platform/cloud/curl_http_request_test.cc
index e31901a7a0f..754f3e4b4b9 100644
--- a/tensorflow/core/platform/cloud/curl_http_request_test.cc
+++ b/tensorflow/core/platform/cloud/curl_http_request_test.cc
@@ -14,10 +14,12 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/platform/cloud/curl_http_request.h"
+
 #include 
+
 #include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/mem.h"
+#include "tensorflow/core/platform/path.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index c55d4ef257e..b6b988047c8 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -28,7 +28,6 @@ limitations under the License.
 #include "absl/base/macros.h"
 #include "include/json/json.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/cloud/curl_http_request.h"
 #include "tensorflow/core/platform/cloud/file_block_cache.h"
 #include "tensorflow/core/platform/cloud/google_auth_provider.h"
@@ -39,6 +38,7 @@ limitations under the License.
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/numbers.h"
+#include "tensorflow/core/platform/path.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/str_util.h"
 #include "tensorflow/core/platform/stringprintf.h"
diff --git a/tensorflow/core/platform/cloud/google_auth_provider.cc b/tensorflow/core/platform/cloud/google_auth_provider.cc
index bb52a5a7ca7..264cb041f77 100644
--- a/tensorflow/core/platform/cloud/google_auth_provider.cc
+++ b/tensorflow/core/platform/cloud/google_auth_provider.cc
@@ -25,11 +25,11 @@ limitations under the License.
 
 #include "absl/strings/match.h"
 #include "include/json/json.h"
-#include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/lib/strings/base64.h"
 #include "tensorflow/core/platform/cloud/retrying_utils.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/path.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/platform/cloud/google_auth_provider_test.cc b/tensorflow/core/platform/cloud/google_auth_provider_test.cc
index 8c7e107037a..5bee2072034 100644
--- a/tensorflow/core/platform/cloud/google_auth_provider_test.cc
+++ b/tensorflow/core/platform/cloud/google_auth_provider_test.cc
@@ -14,10 +14,12 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/platform/cloud/google_auth_provider.h"
+
 #include 
+
 #include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/cloud/http_request_fake.h"
+#include "tensorflow/core/platform/path.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/platform/cloud/oauth_client_test.cc b/tensorflow/core/platform/cloud/oauth_client_test.cc
index ca24365434b..1b04b1cf827 100644
--- a/tensorflow/core/platform/cloud/oauth_client_test.cc
+++ b/tensorflow/core/platform/cloud/oauth_client_test.cc
@@ -21,10 +21,10 @@ limitations under the License.
 #include 
 #include 
 #include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/lib/strings/base64.h"
 #include "tensorflow/core/platform/cloud/http_request_fake.h"
 #include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/path.h"
 #include "tensorflow/core/platform/scanner.h"
 #include "tensorflow/core/platform/test.h"
 
diff --git a/tensorflow/core/platform/default/build_refactor.bzl b/tensorflow/core/platform/default/build_refactor.bzl
index 303a22cdde2..6d1beca6923 100644
--- a/tensorflow/core/platform/default/build_refactor.bzl
+++ b/tensorflow/core/platform/default/build_refactor.bzl
@@ -39,11 +39,9 @@ TF_DEFAULT_PLATFORM_LIBRARIES = {
         ],
         "deps": [
             "@local_config_cuda//cuda:cuda_headers",
-            "//tensorflow/core:lib",
-            # TODO(bmzhao): When bazel gains cc_shared_library support, the targets below are
-            # the actual granular targets we should depend on, instead of tf/core:lib.
-            # "//tensorflow/core/platform:logging",
-            # "//tensorflow/core/platform:types",
+            "//tensorflow/core/platform:logging",
+            "//tensorflow/core/platform:path",
+            "//tensorflow/core/platform:types",
         ],
         "visibility": ["//visibility:private"],
         "tags": ["no_oss", "manual"],
@@ -236,12 +234,9 @@ TF_DEFAULT_PLATFORM_LIBRARIES = {
         ],
         "deps": [
             "@local_config_rocm//rocm:rocm_headers",
-            "//tensorflow/core:lib",
-            # TODO(bmzhao): When bazel gains cc_shared_library support, the targets below are
-            # the actual granular targets we should depend on, instead of tf/core:lib.
-            # "//tensorflow/core/lib/io:path",
-            # "//tensorflow/core/platform:logging",
-            # "//tensorflow/core/platform:types",
+            "//tensorflow/core/platform:path",
+            "//tensorflow/core/platform:logging",
+            "//tensorflow/core/platform:types",
         ],
         "visibility": ["//visibility:private"],
         "tags": ["no_oss", "manual"],
diff --git a/tensorflow/core/platform/default/rocm_rocdl_path.cc b/tensorflow/core/platform/default/rocm_rocdl_path.cc
index 0831544f616..55075969cbd 100644
--- a/tensorflow/core/platform/default/rocm_rocdl_path.cc
+++ b/tensorflow/core/platform/default/rocm_rocdl_path.cc
@@ -17,7 +17,7 @@ limitations under the License.
 
 #include 
 
-#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/path.h"
 
 #if !defined(PLATFORM_GOOGLE) && TENSORFLOW_USE_ROCM
 #include "rocm/rocm_config.h"
diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc
index bee02dbeeed..1f4bd7c6a79 100644
--- a/tensorflow/core/platform/env_test.cc
+++ b/tensorflow/core/platform/env_test.cc
@@ -20,9 +20,9 @@ limitations under the License.
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/cord.h"
 #include "tensorflow/core/platform/null_file_system.h"
+#include "tensorflow/core/platform/path.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/str_util.h"
 #include "tensorflow/core/platform/strcat.h"
diff --git a/tensorflow/core/platform/file_system_test.cc b/tensorflow/core/platform/file_system_test.cc
index 8b577c37c75..278561f4f0d 100644
--- a/tensorflow/core/platform/file_system_test.cc
+++ b/tensorflow/core/platform/file_system_test.cc
@@ -18,8 +18,8 @@ limitations under the License.
 #include 
 
 #include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/null_file_system.h"
+#include "tensorflow/core/platform/path.h"
 #include "tensorflow/core/platform/str_util.h"
 #include "tensorflow/core/platform/strcat.h"
 #include "tensorflow/core/platform/test.h"
diff --git a/tensorflow/core/platform/hadoop/BUILD b/tensorflow/core/platform/hadoop/BUILD
index dc42b901b62..49d9e9975cf 100644
--- a/tensorflow/core/platform/hadoop/BUILD
+++ b/tensorflow/core/platform/hadoop/BUILD
@@ -18,6 +18,7 @@ cc_library(
     deps = [
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/platform:path",
         "//tensorflow/core/platform:strcat",
         "//third_party/hadoop:hdfs",
     ],
@@ -59,6 +60,7 @@ tf_cc_test(
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "//tensorflow/core/platform:path",
         "//tensorflow/core/platform:str_util",
     ],
 )
diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.cc b/tensorflow/core/platform/hadoop/hadoop_file_system.cc
index 6a5c4115189..34dc1cf305b 100644
--- a/tensorflow/core/platform/hadoop/hadoop_file_system.cc
+++ b/tensorflow/core/platform/hadoop/hadoop_file_system.cc
@@ -17,13 +17,13 @@ limitations under the License.
 
 #include 
 
-#include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/error.h"
 #include "tensorflow/core/platform/file_system.h"
 #include "tensorflow/core/platform/file_system_helper.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/path.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/strcat.h"
 #include "third_party/hadoop/hdfs.h"
diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system_test.cc b/tensorflow/core/platform/hadoop/hadoop_file_system_test.cc
index 0c21e9662ee..3104addc4e0 100644
--- a/tensorflow/core/platform/hadoop/hadoop_file_system_test.cc
+++ b/tensorflow/core/platform/hadoop/hadoop_file_system_test.cc
@@ -16,8 +16,8 @@ limitations under the License.
 #include "tensorflow/core/platform/hadoop/hadoop_file_system.h"
 
 #include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/file_system.h"
+#include "tensorflow/core/platform/path.h"
 #include "tensorflow/core/platform/str_util.h"
 #include "tensorflow/core/platform/test.h"
 
diff --git a/tensorflow/core/platform/platform_strings_test.cc b/tensorflow/core/platform/platform_strings_test.cc
index 3824ff550f3..a4af143d58f 100644
--- a/tensorflow/core/platform/platform_strings_test.cc
+++ b/tensorflow/core/platform/platform_strings_test.cc
@@ -25,10 +25,10 @@ limitations under the License.
 #include 
 #include 
 
-#include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/init_main.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/path.h"
 #include "tensorflow/core/platform/str_util.h"
 
 // Embed the platform strings in this binary.
diff --git a/tensorflow/core/platform/rocm_rocdl_path_test.cc b/tensorflow/core/platform/rocm_rocdl_path_test.cc
index 4a4d9b89c59..3436dafac6d 100644
--- a/tensorflow/core/platform/rocm_rocdl_path_test.cc
+++ b/tensorflow/core/platform/rocm_rocdl_path_test.cc
@@ -16,8 +16,8 @@ limitations under the License.
 #include "tensorflow/core/platform/rocm_rocdl_path.h"
 
 #include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/path.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/platform/s3/BUILD b/tensorflow/core/platform/s3/BUILD
index d22a759c440..a5494d5c318 100644
--- a/tensorflow/core/platform/s3/BUILD
+++ b/tensorflow/core/platform/s3/BUILD
@@ -84,6 +84,7 @@ cc_library(
         ":aws_logging",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/platform:path",
         "//tensorflow/core/platform:str_util",
         "@aws",
     ],
@@ -105,6 +106,7 @@ tf_cc_test(
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "//tensorflow/core/platform:path",
         "@aws",
     ],
 )
diff --git a/tensorflow/core/platform/s3/s3_file_system.cc b/tensorflow/core/platform/s3/s3_file_system.cc
index 8c821faa651..936339079cf 100644
--- a/tensorflow/core/platform/s3/s3_file_system.cc
+++ b/tensorflow/core/platform/s3/s3_file_system.cc
@@ -32,9 +32,9 @@ limitations under the License.
 
 #include 
 
-#include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/file_system_helper.h"
 #include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/path.h"
 #include "tensorflow/core/platform/s3/aws_crypto.h"
 #include "tensorflow/core/platform/s3/aws_logging.h"
 #include "tensorflow/core/platform/str_util.h"
diff --git a/tensorflow/core/platform/s3/s3_file_system_test.cc b/tensorflow/core/platform/s3/s3_file_system_test.cc
index e7c3e4a8904..98778495f47 100644
--- a/tensorflow/core/platform/s3/s3_file_system_test.cc
+++ b/tensorflow/core/platform/s3/s3_file_system_test.cc
@@ -16,8 +16,8 @@ limitations under the License.
 #include "tensorflow/core/platform/s3/s3_file_system.h"
 
 #include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/file_system.h"
+#include "tensorflow/core/platform/path.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/platform/windows/windows_file_system.h b/tensorflow/core/platform/windows/windows_file_system.h
index 255f6d59a6f..2e0de725762 100644
--- a/tensorflow/core/platform/windows/windows_file_system.h
+++ b/tensorflow/core/platform/windows/windows_file_system.h
@@ -16,8 +16,8 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_PLATFORM_WINDOWS_WINDOWS_FILE_SYSTEM_H_
 #define TENSORFLOW_CORE_PLATFORM_WINDOWS_WINDOWS_FILE_SYSTEM_H_
 
-#include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/file_system.h"
+#include "tensorflow/core/platform/path.h"
 #include "tensorflow/core/platform/platform.h"
 
 #ifdef PLATFORM_WINDOWS

From 0482a363151fb349f5c39b7d7bdb096c369b6dba Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Sun, 1 Dec 2019 01:03:35 -0800
Subject: [PATCH 127/279] compat: Update forward compatibility horizon to
 2019-12-01

PiperOrigin-RevId: 283199749
Change-Id: I0bd96706c26c0b27459691dabd94f3a4d0a0c51d
---
 tensorflow/python/compat/compat.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 65718014f30..14c7186d7b4 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -31,7 +31,7 @@ from tensorflow.python.util.tf_export import tf_export
 # This value changes every day with an automatic CL. It can be modified in code
 # via `forward_compatibility_horizon()` or with the environment variable
 # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2019, 11, 30)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2019, 12, 1)
 _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
 _FORWARD_COMPATIBILITY_DATE_NUMBER = None
 

From 94f84edd58440cd9145340140232fc9e559b5e02 Mon Sep 17 00:00:00 2001
From: Gunhan Gulsoy 
Date: Sun, 1 Dec 2019 13:43:50 -0800
Subject: [PATCH 128/279] Move blocking counter library from core/lib/core to
 core/platform

PiperOrigin-RevId: 283241034
Change-Id: I7397ccbba64a88ff78ed97ef8a5cc4d87e5397dc
---
 tensorflow/core/BUILD                       |  1 +
 tensorflow/core/lib/core/BUILD              |  3 +-
 tensorflow/core/lib/core/blocking_counter.h | 61 +---------------
 tensorflow/core/platform/BUILD              |  9 +++
 tensorflow/core/platform/blocking_counter.h | 80 +++++++++++++++++++++
 5 files changed, 92 insertions(+), 62 deletions(-)
 create mode 100644 tensorflow/core/platform/blocking_counter.h

diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 6ab27d6644d..06b416fbdaa 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -2261,6 +2261,7 @@ cc_library(
         "//tensorflow/core/lib/strings:strcat",
         "//tensorflow/core/lib/strings:stringprintf",
         "//tensorflow/core/platform:abi",
+        "//tensorflow/core/platform:blocking_counter",
         "//tensorflow/core/platform:coding",
         "//tensorflow/core/platform:context",
         "//tensorflow/core/platform:cord",
diff --git a/tensorflow/core/lib/core/BUILD b/tensorflow/core/lib/core/BUILD
index baca47df789..a3ed21f8771 100644
--- a/tensorflow/core/lib/core/BUILD
+++ b/tensorflow/core/lib/core/BUILD
@@ -37,8 +37,7 @@ cc_library(
     name = "blocking_counter",
     hdrs = ["blocking_counter.h"],
     deps = [
-        "//tensorflow/core/platform:logging",
-        "//tensorflow/core/platform:mutex",
+        "//tensorflow/core/platform:blocking_counter",
     ],
 )
 
diff --git a/tensorflow/core/lib/core/blocking_counter.h b/tensorflow/core/lib/core/blocking_counter.h
index 5dab07dbef9..8355a7ac870 100644
--- a/tensorflow/core/lib/core/blocking_counter.h
+++ b/tensorflow/core/lib/core/blocking_counter.h
@@ -16,65 +16,6 @@ limitations under the License.
 #ifndef TENSORFLOW_LIB_CORE_BLOCKING_COUNTER_H_
 #define TENSORFLOW_LIB_CORE_BLOCKING_COUNTER_H_
 
-#include 
-
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/mutex.h"
-
-namespace tensorflow {
-
-class BlockingCounter {
- public:
-  BlockingCounter(int initial_count)
-      : state_(initial_count << 1), notified_(false) {
-    CHECK_GE(initial_count, 0);
-    DCHECK_EQ((initial_count << 1) >> 1, initial_count);
-  }
-
-  ~BlockingCounter() {}
-
-  inline void DecrementCount() {
-    unsigned int v = state_.fetch_sub(2, std::memory_order_acq_rel) - 2;
-    if (v != 1) {
-      DCHECK_NE(((v + 2) & ~1), 0);
-      return;  // either count has not dropped to 0, or waiter is not waiting
-    }
-    mutex_lock l(mu_);
-    DCHECK(!notified_);
-    notified_ = true;
-    cond_var_.notify_all();
-  }
-
-  inline void Wait() {
-    unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel);
-    if ((v >> 1) == 0) return;
-    mutex_lock l(mu_);
-    while (!notified_) {
-      cond_var_.wait(l);
-    }
-  }
-  // Wait for the specified time, return false iff the count has not dropped to
-  // zero before the timeout expired.
-  inline bool WaitFor(std::chrono::milliseconds ms) {
-    unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel);
-    if ((v >> 1) == 0) return true;
-    mutex_lock l(mu_);
-    while (!notified_) {
-      const std::cv_status status = cond_var_.wait_for(l, ms);
-      if (status == std::cv_status::timeout) {
-        return false;
-      }
-    }
-    return true;
-  }
-
- private:
-  mutex mu_;
-  condition_variable cond_var_;
-  std::atomic state_;  // low bit is waiter flag
-  bool notified_;
-};
-
-}  // namespace tensorflow
+#include "tensorflow/core/platform/blocking_counter.h"
 
 #endif  // TENSORFLOW_LIB_CORE_BLOCKING_COUNTER_H_
diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD
index d4ee72cc11d..2cfe0c31d23 100644
--- a/tensorflow/core/platform/BUILD
+++ b/tensorflow/core/platform/BUILD
@@ -83,6 +83,15 @@ cc_library(
     deps = [":types"],
 )
 
+cc_library(
+    name = "blocking_counter",
+    hdrs = ["blocking_counter.h"],
+    deps = [
+        ":logging",
+        ":mutex",
+    ],
+)
+
 cc_library(
     name = "byte_order",
     hdrs = ["byte_order.h"],
diff --git a/tensorflow/core/platform/blocking_counter.h b/tensorflow/core/platform/blocking_counter.h
new file mode 100644
index 00000000000..9e7ca004024
--- /dev/null
+++ b/tensorflow/core/platform/blocking_counter.h
@@ -0,0 +1,80 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_BLOCKING_COUNTER_H_
+#define TENSORFLOW_CORE_PLATFORM_BLOCKING_COUNTER_H_
+
+#include 
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+class BlockingCounter {
+ public:
+  BlockingCounter(int initial_count)
+      : state_(initial_count << 1), notified_(false) {
+    CHECK_GE(initial_count, 0);
+    DCHECK_EQ((initial_count << 1) >> 1, initial_count);
+  }
+
+  ~BlockingCounter() {}
+
+  inline void DecrementCount() {
+    unsigned int v = state_.fetch_sub(2, std::memory_order_acq_rel) - 2;
+    if (v != 1) {
+      DCHECK_NE(((v + 2) & ~1), 0);
+      return;  // either count has not dropped to 0, or waiter is not waiting
+    }
+    mutex_lock l(mu_);
+    DCHECK(!notified_);
+    notified_ = true;
+    cond_var_.notify_all();
+  }
+
+  inline void Wait() {
+    unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel);
+    if ((v >> 1) == 0) return;
+    mutex_lock l(mu_);
+    while (!notified_) {
+      cond_var_.wait(l);
+    }
+  }
+  // Wait for the specified time, return false iff the count has not dropped to
+  // zero before the timeout expired.
+  inline bool WaitFor(std::chrono::milliseconds ms) {
+    unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel);
+    if ((v >> 1) == 0) return true;
+    mutex_lock l(mu_);
+    while (!notified_) {
+      const std::cv_status status = cond_var_.wait_for(l, ms);
+      if (status == std::cv_status::timeout) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+ private:
+  mutex mu_;
+  condition_variable cond_var_;
+  std::atomic state_;  // low bit is waiter flag
+  bool notified_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PLATFORM_BLOCKING_COUNTER_H_

From 5fcf0482c48d3ba38570bc7a522106c1f2deeeaf Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Sun, 1 Dec 2019 14:57:08 -0800
Subject: [PATCH 129/279] Remove the "tensorflow/core/framework_*_pyclif"
 aliases. Users should use the "tensorflow/core/framework:*_pyclif" targets
 directly.

PiperOrigin-RevId: 283245713
Change-Id: I960f28ecf1bb1b95ab01b00910e824cfab885a41
---
 tensorflow/core/BUILD | 26 --------------------------
 1 file changed, 26 deletions(-)

diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 06b416fbdaa..6a810de58b0 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1977,32 +1977,6 @@ tf_pyclif_proto_library(
     ]
 ]
 
-# The following targets were moved to core/framework. The aliases are only temporary
-# since moving existing users will require several CLs over several projects.
-
-[
-    alias(
-        name = "framework_%s_pyclif%s" % (proto_name, target_suffix),
-        actual = "//tensorflow/core/framework:%s_pyclif%s" % (proto_name, target_suffix),
-        visibility = ["//visibility:public"],
-    )
-    for target_suffix in [
-        "",
-        "_pb2",
-    ]
-    for proto_name in [
-        "cost_graph",
-        "tensor",
-        "kernel_def",
-        "node_def",
-        "function",
-        "graph",
-        "step_stats",
-        "types",
-        "variable",
-    ]
-]
-
 # -----------------------------------------------------------------------------
 # Internal targets
 

From b82ab06d4b596a4aa964b2c78de3478e0cef6dd9 Mon Sep 17 00:00:00 2001
From: Gunhan Gulsoy 
Date: Sun, 1 Dec 2019 15:53:21 -0800
Subject: [PATCH 130/279] Move base64 from core/lib/strings to core/platform

PiperOrigin-RevId: 283249015
Change-Id: I7600a1743d9c9680b8d1960cde96a1c75cc2cdcc
---
 tensorflow/core/lib/strings/BUILD             |  5 +-
 tensorflow/core/lib/strings/base64.h          | 39 +------------
 tensorflow/core/platform/BUILD                | 10 ++++
 .../core/{lib/strings => platform}/base64.cc  |  4 +-
 tensorflow/core/platform/base64.h             | 58 +++++++++++++++++++
 5 files changed, 72 insertions(+), 44 deletions(-)
 rename tensorflow/core/{lib/strings => platform}/base64.cc (98%)
 create mode 100644 tensorflow/core/platform/base64.h

diff --git a/tensorflow/core/lib/strings/BUILD b/tensorflow/core/lib/strings/BUILD
index 598a8bc5a47..31425aabc10 100644
--- a/tensorflow/core/lib/strings/BUILD
+++ b/tensorflow/core/lib/strings/BUILD
@@ -14,11 +14,9 @@ package(
 
 cc_library(
     name = "base64",
-    srcs = ["base64.cc"],
     hdrs = ["base64.h"],
     deps = [
-        "//tensorflow/core/lib/core:errors",
-        "//tensorflow/core/lib/core:status",
+        "//tensorflow/core/platform:base64",
     ],
 )
 
@@ -113,7 +111,6 @@ filegroup(
 filegroup(
     name = "legacy_lib_strings_all_srcs",
     srcs = [
-        "base64.cc",
         "ordered_code.cc",
         "proto_serialization.cc",
         "proto_text_util.cc",
diff --git a/tensorflow/core/lib/strings/base64.h b/tensorflow/core/lib/strings/base64.h
index 15a273b36a9..bb7cbfb3777 100644
--- a/tensorflow/core/lib/strings/base64.h
+++ b/tensorflow/core/lib/strings/base64.h
@@ -16,43 +16,6 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_LIB_STRINGS_BASE64_H_
 #define TENSORFLOW_CORE_LIB_STRINGS_BASE64_H_
 
-#include 
-#include "tensorflow/core/lib/core/status.h"
-
-namespace tensorflow {
-
-/// \brief Converts data into web-safe base64 encoding.
-///
-/// See https://en.wikipedia.org/wiki/Base64
-template 
-Status Base64Encode(StringPiece source, bool with_padding, T* encoded);
-template 
-Status Base64Encode(StringPiece source,
-                    T* encoded);  // with_padding=false.
-
-/// \brief Converts data from web-safe base64 encoding.
-///
-/// See https://en.wikipedia.org/wiki/Base64
-template 
-Status Base64Decode(StringPiece data, T* decoded);
-
-// Explicit instantiations defined in base64.cc.
-extern template Status Base64Decode(StringPiece data, string* decoded);
-extern template Status Base64Encode(StringPiece source,
-                                            string* encoded);
-extern template Status Base64Encode(StringPiece source,
-                                            bool with_padding, string* encoded);
-
-#ifdef USE_TSTRING
-extern template Status Base64Decode(StringPiece data,
-                                             tstring* decoded);
-extern template Status Base64Encode(StringPiece source,
-                                             tstring* encoded);
-extern template Status Base64Encode(StringPiece source,
-                                             bool with_padding,
-                                             tstring* encoded);
-#endif  // USE_TSTRING
-
-}  // namespace tensorflow
+#include "tensorflow/core/platform/base64.h"
 
 #endif  // TENSORFLOW_CORE_LIB_STRINGS_BASE64_H_
diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD
index 2cfe0c31d23..534a3b4e7cc 100644
--- a/tensorflow/core/platform/BUILD
+++ b/tensorflow/core/platform/BUILD
@@ -83,6 +83,16 @@ cc_library(
     deps = [":types"],
 )
 
+cc_library(
+    name = "base64",
+    srcs = ["base64.cc"],
+    hdrs = ["base64.h"],
+    deps = [
+        ":errors",
+        ":status",
+    ],
+)
+
 cc_library(
     name = "blocking_counter",
     hdrs = ["blocking_counter.h"],
diff --git a/tensorflow/core/lib/strings/base64.cc b/tensorflow/core/platform/base64.cc
similarity index 98%
rename from tensorflow/core/lib/strings/base64.cc
rename to tensorflow/core/platform/base64.cc
index 80eec3a9403..0ff690f1b32 100644
--- a/tensorflow/core/lib/strings/base64.cc
+++ b/tensorflow/core/platform/base64.cc
@@ -13,11 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/lib/strings/base64.h"
+#include "tensorflow/core/platform/base64.h"
 
 #include 
 #include 
-#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/errors.h"
 
 namespace tensorflow {
 namespace {
diff --git a/tensorflow/core/platform/base64.h b/tensorflow/core/platform/base64.h
new file mode 100644
index 00000000000..7b764732dc9
--- /dev/null
+++ b/tensorflow/core/platform/base64.h
@@ -0,0 +1,58 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_BASE64_H_
+#define TENSORFLOW_CORE_PLATFORM_BASE64_H_
+
+#include 
+#include "tensorflow/core/platform/status.h"
+
+namespace tensorflow {
+
+/// \brief Converts data into web-safe base64 encoding.
+///
+/// See https://en.wikipedia.org/wiki/Base64
+template 
+Status Base64Encode(StringPiece source, bool with_padding, T* encoded);
+template 
+Status Base64Encode(StringPiece source,
+                    T* encoded);  // with_padding=false.
+
+/// \brief Converts data from web-safe base64 encoding.
+///
+/// See https://en.wikipedia.org/wiki/Base64
+template 
+Status Base64Decode(StringPiece data, T* decoded);
+
+// Explicit instantiations defined in base64.cc.
+extern template Status Base64Decode(StringPiece data, string* decoded);
+extern template Status Base64Encode(StringPiece source,
+                                            string* encoded);
+extern template Status Base64Encode(StringPiece source,
+                                            bool with_padding, string* encoded);
+
+#ifdef USE_TSTRING
+extern template Status Base64Decode(StringPiece data,
+                                             tstring* decoded);
+extern template Status Base64Encode(StringPiece source,
+                                             tstring* encoded);
+extern template Status Base64Encode(StringPiece source,
+                                             bool with_padding,
+                                             tstring* encoded);
+#endif  // USE_TSTRING
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PLATFORM_BASE64_H_

From ab1bf04841308e0276ab799d6e5474395d640fa1 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Mon, 2 Dec 2019 00:58:50 -0800
Subject: [PATCH 131/279] Allows disabling delegation to NNAPI CPU on Android
 10 when no accelerator name is specified. NNAPI CPU typically performs worse
 than TfLite on its own; but allowing CPU enables partial acceleration of a
 model. Specifying this behaviour is not possible before Android 10.

PiperOrigin-RevId: 283290356
Change-Id: Ie58cdb1391082584923b6329dc96b05d92decaaf
---
 tensorflow/lite/delegates/nnapi/BUILD         |  24 +++
 .../lite/delegates/nnapi/nnapi_delegate.cc    |  41 +++-
 .../lite/delegates/nnapi/nnapi_delegate.h     |   8 +
 .../nnapi_delegate_device_selection_test.cc   | 190 ++++++++++++++++++
 .../delegates/nnapi/nnapi_delegate_kernel.h   |   2 +-
 .../nnapi/nnapi_delegate_mock_test.h          |   2 +-
 6 files changed, 257 insertions(+), 10 deletions(-)
 create mode 100644 tensorflow/lite/delegates/nnapi/nnapi_delegate_device_selection_test.cc

diff --git a/tensorflow/lite/delegates/nnapi/BUILD b/tensorflow/lite/delegates/nnapi/BUILD
index 0e99e1e3b79..54251676da3 100644
--- a/tensorflow/lite/delegates/nnapi/BUILD
+++ b/tensorflow/lite/delegates/nnapi/BUILD
@@ -157,6 +157,30 @@ cc_test(
     ],
 )
 
+cc_test(
+    name = "nnapi_delegate_device_selection_test",
+    size = "small",
+    srcs = [
+        "nnapi_delegate_device_selection_test.cc",
+    ],
+    tags = [
+        "no_mac",
+        "no_windows",
+        "tflite_not_portable_ios",
+    ],
+    deps = [
+        ":nnapi_delegate",
+        ":nnapi_delegate_mock_test",
+        "//tensorflow/lite:framework",
+        "//tensorflow/lite:minimal_logging",
+        "//tensorflow/lite/c:common",
+        "//tensorflow/lite/kernels:test_util",
+        "//tensorflow/lite/nnapi:nnapi_implementation",
+        "//tensorflow/lite/nnapi:nnapi_lib",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
 cc_test(
     name = "quant_lstm_sup_test",
     size = "small",
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
index 3e4967aebfc..cc73f3020e5 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
@@ -2873,12 +2873,34 @@ TfLiteStatus NNAPIDelegateKernel::Init(TfLiteContext* context,
   const auto delegate_options =
       StatefulNnApiDelegate::GetOptions(params->delegate);
   const char* device_name_ptr = delegate_options.accelerator_name;
-  // user specified an acclelerator to use.
-  if (nnapi_->android_sdk_version >= kMinSdkVersionForNNAPI12 &&
-      device_name_ptr != nullptr) {
-    nnapi_device_ = GetDeviceHandle(context, device_name_ptr);
-    if (nnapi_device_ == nullptr) {
-      return kTfLiteError;
+  if (nnapi_->android_sdk_version >= kMinSdkVersionForNNAPI12) {
+    if (device_name_ptr != nullptr) {
+      // User specified an accelerator to use.
+      ANeuralNetworksDevice* nnapi_device =
+          GetDeviceHandle(context, device_name_ptr);
+      if (nnapi_device == nullptr) {
+        return kTfLiteError;
+      }
+      nnapi_devices_.push_back(nnapi_device);
+    } else if (delegate_options.disallow_nnapi_cpu) {
+      std::string nnapi_cpu("nnapi-reference");
+      uint32_t num_devices = 0;
+      NnApiImplementation()->ANeuralNetworks_getDeviceCount(&num_devices);
+
+      for (uint32_t i = 0; i < num_devices; i++) {
+        ANeuralNetworksDevice* device = nullptr;
+        const char* buffer = nullptr;
+        NnApiImplementation()->ANeuralNetworks_getDevice(i, &device);
+        NnApiImplementation()->ANeuralNetworksDevice_getName(device, &buffer);
+        if (nnapi_cpu != buffer) {
+          nnapi_devices_.push_back(device);
+        }
+      }
+      if (nnapi_devices_.empty()) {
+        context->ReportError(
+            context, "NNAPI delegate requested but no accelerators available.");
+        return kTfLiteError;
+      }
     }
   }
 
@@ -2898,12 +2920,13 @@ TfLiteStatus NNAPIDelegateKernel::Init(TfLiteContext* context,
 
   if (!nn_compilation_) {
     ANeuralNetworksCompilation* compilation = nullptr;
-    if (nnapi_device_ != nullptr) {
+    if (!nnapi_devices_.empty()) {
       // Compile for the selected accelerator.
       RETURN_TFLITE_ERROR_IF_NN_ERROR(
           context,
           nnapi_->ANeuralNetworksCompilation_createForDevices(
-              nn_model_.get(), &nnapi_device_, 1, &compilation),
+              nn_model_.get(), nnapi_devices_.data(), nnapi_devices_.size(),
+              &compilation),
           nnapi_errno);
     } else {
       RETURN_TFLITE_ERROR_IF_NN_ERROR(context,
@@ -3587,6 +3610,7 @@ StatefulNnApiDelegate::StatefulNnApiDelegate(Options options)
   if (options.model_token) {
     delegate_data_.model_token = options.model_token;
   }
+  delegate_data_.disallow_nnapi_cpu = options.disallow_nnapi_cpu;
   TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO,
                        "Created TensorFlow Lite delegate for NNAPI.");
   Prepare = DoPrepare;
@@ -3613,6 +3637,7 @@ const StatefulNnApiDelegate::Options StatefulNnApiDelegate::GetOptions(
   options.model_token = delegate_data->model_token.empty()
                             ? nullptr
                             : delegate_data->model_token.c_str();
+  options.disallow_nnapi_cpu = delegate_data->disallow_nnapi_cpu;
   return options;
 }
 
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.h b/tensorflow/lite/delegates/nnapi/nnapi_delegate.h
index 9fdbe626320..022e9ed53ac 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.h
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.h
@@ -63,6 +63,12 @@ class StatefulNnApiDelegate : public TfLiteDelegate {
     // NOTE: when using compilation caching, it is not recommended to use the
     // same delegate instance for multiple models.
     const char* model_token = nullptr;
+
+    // Whether to disallow NNAPI CPU usage. Only effective on Android 10 and
+    // above. The NNAPI CPU typically performs less well than built-in TfLite
+    // kernels, but allowing CPU allows partial acceleration of models. If this
+    // is set to true, NNAPI is only used if the whole model is accelerated.
+    bool disallow_nnapi_cpu = false;
   };
 
   // Uses default options.
@@ -131,6 +137,8 @@ class StatefulNnApiDelegate : public TfLiteDelegate {
     std::string cache_dir;
     // The unique token string for NNAPI model.
     std::string model_token;
+    // Whether to disallow NNAPI CPU.
+    bool disallow_nnapi_cpu;
     // Tensor to ANeuralNetworksMemory mapping.
     std::vector tensor_memory_map;
     // Constains a non zero value if any NNAPI method call
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_device_selection_test.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate_device_selection_test.cc
new file mode 100644
index 00000000000..146bf1eaa47
--- /dev/null
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_device_selection_test.cc
@@ -0,0 +1,190 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include 
+
+#include 
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
+#include "tensorflow/lite/delegates/nnapi/nnapi_delegate_mock_test.h"
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/model.h"
+#include "tensorflow/lite/nnapi/NeuralNetworksTypes.h"
+#include "tensorflow/lite/nnapi/nnapi_implementation.h"
+
+namespace tflite {
+namespace {
+
+class SingleOpModelWithNNAPI : public SingleOpModel {
+ public:
+  SingleOpModelWithNNAPI() = default;
+  void Init(tflite::StatefulNnApiDelegate::Options options) {
+    stateful_delegate_.reset(new StatefulNnApiDelegate(options));
+    auto* delegate = stateful_delegate_.get();
+    this->SetApplyDelegate([delegate, this](Interpreter* interpreter) {
+      compilation_status_ = interpreter->ModifyGraphWithDelegate(delegate);
+    });
+  }
+
+  StatefulNnApiDelegate* GetDelegate() { return stateful_delegate_.get(); }
+
+  void SetBufferHandle(int index, TfLiteBufferHandle handle) {
+    interpreter_->SetBufferHandle(index, handle, stateful_delegate_.get());
+  }
+  TfLiteStatus GetCompilationStatus() { return compilation_status_; }
+
+ private:
+  std::unique_ptr stateful_delegate_;
+  TfLiteStatus compilation_status_;
+};
+
+class FloatAddOpModel : public SingleOpModelWithNNAPI {
+ public:
+  FloatAddOpModel() = default;
+  void Init(tflite::StatefulNnApiDelegate::Options options,
+            const TensorData& input1, const TensorData& input2,
+            const TensorData& output, ActivationFunctionType activation_type,
+            bool allow_fp32_relax_to_fp16 = false) {
+    SingleOpModelWithNNAPI::Init(options);
+    input1_ = AddInput(input1);
+    input2_ = AddInput(input2);
+    output_ = AddOutput(output);
+    SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions,
+                 CreateAddOptions(builder_, activation_type).Union());
+    BuildInterpreter({GetShape(input1_), GetShape(input2_)},
+                     allow_fp32_relax_to_fp16);
+  }
+
+  int input1() { return input1_; }
+  int input2() { return input2_; }
+
+  std::vector GetOutput() { return ExtractVector(output_); }
+
+ protected:
+  int input1_;
+  int input2_;
+  int output_;
+
+ private:
+};
+
+struct NnApiDeviceSelectionTest
+    : ::tflite::delegate::nnapi::NnApiDelegateMockTest {
+  void SetUp() override {
+    ::tflite::delegate::nnapi::NnApiDelegateMockTest::SetUp();
+    nnapi_->ANeuralNetworks_getDeviceCount = [](uint32_t* numDevices) -> int {
+      *numDevices = 3;
+      return 0;
+    };
+    nnapi_->ANeuralNetworks_getDevice =
+        [](uint32_t devIndex, ANeuralNetworksDevice** device) -> int {
+      *device = reinterpret_cast(devIndex + 1);
+      return 0;
+    };
+    nnapi_->ANeuralNetworksDevice_getName =
+        [](const ANeuralNetworksDevice* device, const char** name) -> int {
+      if (device == reinterpret_cast(1)) {
+        *name = "dsp";
+      } else if (device == reinterpret_cast(2)) {
+        *name = "gpu";
+      } else {
+        *name = "nnapi-reference";
+      }
+      return 0;
+    };
+  }
+  void InitWithOptions(tflite::StatefulNnApiDelegate::Options options) {
+    m.Init(options, {TensorType_FLOAT32, {1, 2, 2, 1}},
+           {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {}},
+           ActivationFunctionType_NONE);
+    m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8});
+    m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5});
+  }
+  FloatAddOpModel m;
+};
+
+TEST_F(NnApiDeviceSelectionTest, DoesntSetDevicesWithoutFlags) {
+  nnapi_->ANeuralNetworksCompilation_createForDevices =
+      [](ANeuralNetworksModel* model,
+         const ANeuralNetworksDevice* const* devices, uint32_t numDevices,
+         ANeuralNetworksCompilation** compilation) -> int {
+    EXPECT_TRUE(false) << "Should not call createForDevices";
+    return 1;
+  };
+
+  tflite::StatefulNnApiDelegate::Options options;
+  InitWithOptions(options);
+  m.Invoke();
+  EXPECT_EQ(m.GetCompilationStatus(), kTfLiteOk);
+}
+
+TEST_F(NnApiDeviceSelectionTest, SetsDeviceBasedOnOptions) {
+  nnapi_mock_->CompilationCreateReturns<1>();
+  nnapi_->ANeuralNetworksCompilation_createForDevices =
+      [](ANeuralNetworksModel* model,
+         const ANeuralNetworksDevice* const* devices, uint32_t numDevices,
+         ANeuralNetworksCompilation** compilation) -> int {
+    EXPECT_EQ(numDevices, 1);
+    EXPECT_EQ(devices[0], reinterpret_cast(1));
+    if (numDevices != 1 ||
+        devices[0] != reinterpret_cast(1)) {
+      return 1;
+    } else {
+      *compilation = reinterpret_cast(3);
+      return 0;
+    }
+  };
+
+  tflite::StatefulNnApiDelegate::Options options;
+  options.accelerator_name = "dsp";
+  InitWithOptions(options);
+  m.Invoke();
+  EXPECT_EQ(m.GetCompilationStatus(), kTfLiteOk);
+}
+
+TEST_F(NnApiDeviceSelectionTest, DisallowsCPUBasedOnOptions) {
+  nnapi_mock_->CompilationCreateReturns<1>();
+  nnapi_->ANeuralNetworksCompilation_createForDevices =
+      [](ANeuralNetworksModel* model,
+         const ANeuralNetworksDevice* const* devices, uint32_t numDevices,
+         ANeuralNetworksCompilation** compilation) -> int {
+    EXPECT_EQ(numDevices, 2);
+    EXPECT_EQ(devices[0], reinterpret_cast(1));
+    EXPECT_EQ(devices[1], reinterpret_cast(2));
+    if (numDevices != 2 ||
+        devices[0] != reinterpret_cast(1) ||
+        devices[1] != reinterpret_cast(2)) {
+      return 1;
+    } else {
+      *compilation = reinterpret_cast(3);
+      return 0;
+    }
+  };
+
+  tflite::StatefulNnApiDelegate::Options options;
+  options.disallow_nnapi_cpu = true;
+  InitWithOptions(options);
+  m.Invoke();
+  EXPECT_EQ(m.GetCompilationStatus(), kTfLiteOk);
+}
+
+}  // namespace
+}  // namespace tflite
+
+int main(int argc, char** argv) {
+  ::tflite::LogToStderr();
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h b/tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h
index 5390b181583..6a9493f9f4d 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h
@@ -285,7 +285,7 @@ class NNAPIDelegateKernel {
   // Access to NNApi.
   const NnApi* nnapi_;
   // ANN device handle.
-  ANeuralNetworksDevice* nnapi_device_ = nullptr;
+  std::vector nnapi_devices_;
   // ANN API state.
   std::unique_ptr nn_model_;
   std::unique_ptr
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_mock_test.h b/tensorflow/lite/delegates/nnapi/nnapi_delegate_mock_test.h
index 24eb06edabe..4a48409de1e 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_mock_test.h
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_mock_test.h
@@ -73,12 +73,12 @@ class NnApiMock : public ::tflite::nnapi::NnApiHandler {
 };
 
 class NnApiDelegateMockTest : public ::testing::Test {
+ protected:
   void SetUp() override {
     nnapi_ = const_cast(NnApiImplementation());
     nnapi_mock_ = absl::make_unique(nnapi_);
   }
 
- protected:
   NnApi* nnapi_;
   std::unique_ptr nnapi_mock_;
 };

From 456e027373d553ad10c15ceaeee1eb78041b79dc Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Mon, 2 Dec 2019 01:03:18 -0800
Subject: [PATCH 132/279] compat: Update forward compatibility horizon to
 2019-12-02

PiperOrigin-RevId: 283291364
Change-Id: I2a27ee64285df676789cbf41a9ff4086b9848f9d
---
 tensorflow/python/compat/compat.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 14c7186d7b4..69e6fdd100a 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -31,7 +31,7 @@ from tensorflow.python.util.tf_export import tf_export
 # This value changes every day with an automatic CL. It can be modified in code
 # via `forward_compatibility_horizon()` or with the environment variable
 # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2019, 12, 1)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2019, 12, 2)
 _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
 _FORWARD_COMPATIBILITY_DATE_NUMBER = None
 

From c80cfd95e069c6868f7e3507443c4a5d1ec185ed Mon Sep 17 00:00:00 2001
From: Chris Jones 
Date: Mon, 2 Dec 2019 03:22:01 -0800
Subject: [PATCH 133/279] Remove distribution strategy device map code.

PiperOrigin-RevId: 283308832
Change-Id: I7d1c4fd981a29fa07f45317539cc9e4ef120c308
---
 .../collective_all_reduce_strategy.py         |   5 +-
 .../python/distribute/cross_device_ops.py     |  89 ++-
 .../distribute/cross_device_ops_test.py       |  17 +-
 .../distribute/cross_device_utils_test.py     |   3 +-
 .../python/distribute/distribute_lib_test.py  |   5 +-
 tensorflow/python/distribute/input_lib.py     |  45 +-
 .../python/distribute/input_lib_test.py       |  11 +-
 .../distribute/mirrored_function_strategy.py  |   8 +-
 .../python/distribute/mirrored_strategy.py    |  87 ++-
 .../distribute/mirrored_strategy_test.py      |  33 +-
 .../distribute/mirrored_variable_test.py      |  53 +-
 .../python/distribute/one_device_strategy.py  |   4 +-
 .../distribute/parameter_server_strategy.py   |  28 +-
 tensorflow/python/distribute/tpu_strategy.py  |  35 +-
 tensorflow/python/distribute/values.py        | 525 +++---------------
 tensorflow/python/distribute/values_test.py   | 427 +++++---------
 .../keras/distribute/keras_utils_test.py      |  10 +-
 tensorflow/python/module/module_test.py       |   5 +-
 .../python/ops/stateful_random_ops_test.py    |   4 +-
 tensorflow/python/tpu/tpu.py                  |  29 +-
 tensorflow/python/training/optimizer.py       |   2 +-
 21 files changed, 407 insertions(+), 1018 deletions(-)

diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py
index 507e7779cfe..89d13c0777f 100644
--- a/tensorflow/python/distribute/collective_all_reduce_strategy.py
+++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py
@@ -209,6 +209,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
         local_devices = tuple("/device:GPU:%d" % i for i in range(num_gpus))
       else:
         local_devices = ("/device:CPU:0",)
+
     self._worker_device = device_util.canonicalize("/device:CPU:0")
     self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)
 
@@ -327,7 +328,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
     super(CollectiveAllReduceExtended, self)._initialize_single_worker(
         local_devices)
     self._input_workers = input_lib.InputWorkers(
-        self._device_map, [(self._worker_device, self.worker_devices)])
+        [(self._worker_device, self.worker_devices)])
 
     # Add a default device so that ops without specified devices will not end up
     # on other workers.
@@ -523,7 +524,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
       # replicas in which case `value` would be a single value or value could
       # be 0.
       return cross_device_ops_lib.reduce_non_distributed_value(
-          reduce_op, self._device_map, value, destinations)
+          reduce_op, value, destinations, len(self.worker_devices))
     return self._get_cross_device_ops().reduce(
         reduce_op, value, destinations=destinations)
 
diff --git a/tensorflow/python/distribute/cross_device_ops.py b/tensorflow/python/distribute/cross_device_ops.py
index 9fc49df0ead..07aab81587a 100644
--- a/tensorflow/python/distribute/cross_device_ops.py
+++ b/tensorflow/python/distribute/cross_device_ops.py
@@ -65,10 +65,7 @@ def validate_destinations(destinations):
           ops.Tensor,
           value_lib.AggregatingVariable,
           six.string_types,
-          value_lib.TPUMirroredVariable,
-          # LogicalDeviceSpec is only used internally, e.g. as a
-          # broadcast destination, never supplied by a user.
-          value_lib.LogicalDeviceSpec)):
+          value_lib.TPUMirroredVariable)):
     raise ValueError("destinations must be one of a `DistributedValues` object,"
                      " a tf.Variable object, or a device string.")
 
@@ -76,7 +73,8 @@ def validate_destinations(destinations):
     raise ValueError("destinations can not be empty")
 
 
-def reduce_non_distributed_value(reduce_op, device_map, value, destinations):
+def reduce_non_distributed_value(
+    reduce_op, value, destinations, num_replicas_in_graph):
   """Reduce a non-DistributedValue `value` to `destinations`."""
   if isinstance(value, value_lib.DistributedValues):
     raise ValueError("You are passing a `DistributedValue` to "
@@ -92,15 +90,16 @@ def reduce_non_distributed_value(reduce_op, device_map, value, destinations):
   # that value should be on all destinations.
   if reduce_op == reduce_util.ReduceOp.MEAN:
     return value
-
-  validate_destinations(destinations)
-  # We do not support a reduce op of SUM if the value is the same across
-  # all replicas. We call this as part of assign functions for MirroredVariables
-  # and summing up identical values across replicas is not clearly defined.
-  if device_map.num_replicas_in_graph != 1:
+  elif num_replicas_in_graph != 1:
+    # We do not support a reduce op of SUM if the value is the same across
+    # all replicas. We call this as part of assign functions for
+    # MirroredVariables and summing up identical values across replicas is not
+    # clearly defined.
     raise ValueError("A non-DistributedValues value %s cannot be reduced with "
                      "the given reduce op %s." % (value, reduce_op))
-  return simple_broadcast(value, destinations)
+  else:
+    validate_destinations(destinations)
+    return simple_broadcast(value, destinations)
 
 
 def _make_tensor_into_per_replica(input_tensor):
@@ -111,16 +110,12 @@ def _make_tensor_into_per_replica(input_tensor):
                      % (input_tensor,))
   if isinstance(input_tensor, value_lib.PerReplica):
     return input_tensor
-
-  try:
-    device = input_tensor.device
-  except AttributeError:
+  elif hasattr(input_tensor, "device"):
+    return value_lib.PerReplica((input_tensor,))
+  else:
     raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object "
                      "because it doesn't have device set.")
 
-  device_map = value_lib.SingleDeviceMap(device)
-  return value_lib.PerReplica(device_map, (input_tensor,))
-
 
 def _normalize_value_destination_pairs(value_destination_pairs):
   """Converts each tensor into a PerReplica object in the input list."""
@@ -161,25 +156,11 @@ def _validate_value_destination_pairs(value_destination_pairs):
 def get_devices_from(destinations):
   if isinstance(destinations, value_lib.DistributedValues):
     return destinations.devices
-  elif isinstance(destinations, value_lib.LogicalDeviceSpec):
-    return destinations.device_map.logical_to_actual_devices(
-        destinations.logical_device)
   elif isinstance(destinations, six.string_types):
     return (device_util.resolve(destinations),)
   return (device_util.resolve(destinations.device),)
 
 
-def get_device_map_from(destinations):
-  if isinstance(destinations, (value_lib.DistributedValues,
-                               value_lib.LogicalDeviceSpec)):
-    return destinations.device_map, destinations.logical_device
-  if isinstance(destinations, six.string_types):
-    device = device_util.resolve(destinations)
-  else:
-    device = destinations.device
-  return value_lib.SingleDeviceMap(device), 0
-
-
 def _devices_match(left, right):
   return set(get_devices_from(left)) == set(get_devices_from(right))
 
@@ -195,8 +176,7 @@ def _all_devices_match(value_destination_pairs):
 
 def simple_broadcast(value, destinations, always_mirrored=False):
   """Broadcast `value` to `destinations` using simple copies."""
-  device_map, logical_device = get_device_map_from(destinations)
-  devices = device_map.logical_to_actual_devices(logical_device)
+  devices = get_devices_from(destinations)
   if len(devices) == 1 and not always_mirrored:
     return cross_device_utils.copy_tensor_or_indexed_slices_to_device(
         value, devices[0])
@@ -204,10 +184,8 @@ def simple_broadcast(value, destinations, always_mirrored=False):
     value_updates = []
     for d in devices:
       value_updates.append(
-          cross_device_utils.copy_tensor_or_indexed_slices_to_device(
-              value, d))
-    return value_lib.regroup(
-        device_map, value_updates, wrap_class=value_lib.Mirrored)
+          cross_device_utils.copy_tensor_or_indexed_slices_to_device(value, d))
+    return value_lib.regroup(value_updates, wrap_class=value_lib.Mirrored)
 
 
 def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
@@ -274,7 +252,6 @@ class CrossDeviceOps(object):
         per_replica_value.values) == 1 and _devices_match(
             per_replica_value, destinations):
       return value_lib.regroup(
-          per_replica_value.device_map,
           per_replica_value.values,
           wrap_class=value_lib.Mirrored)
 
@@ -319,8 +296,7 @@ class CrossDeviceOps(object):
         value_destination_pairs) and len(
             value_destination_pairs[0][0].values) == 1:
       return [
-          value_lib.regroup(
-              v.device_map, v.values, wrap_class=value_lib.Mirrored)
+          value_lib.regroup(v.values, wrap_class=value_lib.Mirrored)
           for v, _ in value_destination_pairs
       ]
 
@@ -498,8 +474,7 @@ def _ungroup_and_make_mirrored(grouped_reduced,
   Returns:
     a list of Mirrored objects.
   """
-  device_map, _ = get_device_map_from(destinations)
-  num_replicas = device_map.num_replicas_in_graph * num_between_graph_workers
+  num_replicas = len(get_devices_from(destinations)) * num_between_graph_workers
   index = [[] for _ in range(len(grouped_reduced[0]))]
   for per_replica_reduced in grouped_reduced:
     for i, (v, _) in enumerate(per_replica_reduced):
@@ -508,10 +483,7 @@ def _ungroup_and_make_mirrored(grouped_reduced,
           index[i].append(v / num_replicas)
       else:
         index[i].append(v)
-  return [
-      value_lib.regroup(device_map, v, wrap_class=value_lib.Mirrored)
-      for v in index
-  ]
+  return [value_lib.regroup(v, wrap_class=value_lib.Mirrored) for v in index]
 
 
 class _ConcatAndSplitPacker(object):
@@ -1036,32 +1008,33 @@ class CollectiveAllReduce(CrossDeviceOps):
 
   def reduce_implementation(self, reduce_op, per_replica_value, destinations):
     all_reduced = self._batch_all_reduce(reduce_op, [per_replica_value])[0]
-    device_map, logical_device = get_device_map_from(destinations)
-    devices = device_map.logical_to_actual_devices(logical_device)
+    devices = get_devices_from(destinations)
 
     if (isinstance(all_reduced, value_lib.Mirrored) and
-        all_reduced.device_map is device_map and
-        all_reduced.logical_device == logical_device):
+        (all_reduced.devices == devices)):
       return all_reduced
 
     # Convert `all_reduced` to a `Mirrored` object, as a simple and uniform
     # utility to access component for a particular device.
     if not isinstance(all_reduced, value_lib.Mirrored):
-      all_reduced = value_lib.Mirrored(
-          value_lib.SingleDeviceMap(all_reduced.device), [all_reduced])
+      all_reduced = value_lib.Mirrored([all_reduced])
 
+    # If we got this far, the destination devices do not match the all-reduce
+    # devices, so we must map from one to the other.
     index = []
+    # We must add these control dependencies, otherwise we can get deadlock.
     with ops.control_dependencies(all_reduced.values):
       for d in devices:
         with ops.device(d):
-          if d in all_reduced.devices:
-            index.append(array_ops.identity(all_reduced.get(d)))
+          for v in all_reduced.values:
+            if v.device == d:
+              index.append(array_ops.identity(v))
+              break
           else:
             # TODO(josh11b): Once we add support for model parallelism, get the
             # copy from the corresponding replica instead of the primary.
             index.append(array_ops.identity(all_reduced.primary))
-
-    return value_lib.regroup(device_map, index, wrap_class=value_lib.Mirrored)
+    return value_lib.regroup(index, wrap_class=value_lib.Mirrored)
 
   def batch_reduce_implementation(self, reduce_op, value_destination_pairs):
     all_devices_match = _all_devices_match(value_destination_pairs)
diff --git a/tensorflow/python/distribute/cross_device_ops_test.py b/tensorflow/python/distribute/cross_device_ops_test.py
index 0ec049ad4c1..7af7c48e57d 100644
--- a/tensorflow/python/distribute/cross_device_ops_test.py
+++ b/tensorflow/python/distribute/cross_device_ops_test.py
@@ -66,7 +66,7 @@ def _make_per_replica(values, devices, regroup=False):
     with ops.device(d):
       placed_v = array_ops.identity(v)
     index.append(placed_v)
-  return value_lib.regroup(value_lib.ReplicaDeviceMap(devices), index)
+  return value_lib.regroup(index)
 
 
 # pylint: disable=g-doc-args,g-doc-return-or-yield
@@ -82,7 +82,6 @@ def _fake_mirrored(value, devices):
     with ops.device(d):
       values.append(array_ops.identity(value))
   return value_lib.regroup(
-      value_lib.ReplicaDeviceMap(devices),
       values,
       wrap_class=value_lib.Mirrored)
 
@@ -100,7 +99,6 @@ def _make_mirrored_indexed_slices(devices, values, indices, dense_shape):
   values = [_make_indexed_slices(values, indices, dense_shape, d)
             for d in devices]
   return value_lib.regroup(
-      value_lib.ReplicaDeviceMap(devices),
       values,
       wrap_class=value_lib.Mirrored)
 
@@ -127,8 +125,7 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase):
     else:
       if isinstance(left, value_lib.DistributedValues):
         self.assertEqual(set(left.devices), set(right.devices))
-        self._assert_values_equal([left.get(d) for d in sorted(left.devices)],
-                                  [right.get(d) for d in sorted(right.devices)])
+        self._assert_values_equal(left.values, right.values)
       else:
         self.assertEqual(
             device_util.resolve(left.device), device_util.resolve(right.device))
@@ -217,8 +214,7 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase):
     t0 = _make_indexed_slices([[1., 2.]], [1], dense_shape, devices[0])
     t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], dense_shape,
                               devices[1])
-    per_replica = value_lib.PerReplica(
-        value_lib.ReplicaDeviceMap(devices), (t0, t1))
+    per_replica = value_lib.PerReplica((t0, t1))
 
     if batch_reduce:
       result = cross_device_ops_instance.batch_reduce(
@@ -339,7 +335,6 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase):
           cross_device_ops_lib.choose_the_best(devices),
           cross_device_ops_lib.ReductionToOneDevice)
 
-
   @combinations.generate(combinations.combine(
       mode=["graph", "eager"],
       required_gpus=1))
@@ -347,8 +342,7 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase):
     devices = ["/cpu:0", "/gpu:0"]
     t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0])
     t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1])
-    per_replica = value_lib.PerReplica(
-        value_lib.ReplicaDeviceMap(devices), (t0, t1))
+    per_replica = value_lib.PerReplica((t0, t1))
     result = cross_device_ops_lib._simple_reduce(
         per_replica, devices[0], math_ops.add_n, reduce_util.ReduceOp.SUM)
 
@@ -648,8 +642,7 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
       indexed_slices.append(
           _make_indexed_slices(values[idx], indices[idx], dense_shape, d))
     if as_per_replica:
-      per_replica = value_lib.PerReplica(
-          value_lib.ReplicaDeviceMap(devices), indexed_slices)
+      per_replica = value_lib.PerReplica(indexed_slices)
       return per_replica
     else:
       return indexed_slices
diff --git a/tensorflow/python/distribute/cross_device_utils_test.py b/tensorflow/python/distribute/cross_device_utils_test.py
index 16caad7615a..217883ea21b 100644
--- a/tensorflow/python/distribute/cross_device_utils_test.py
+++ b/tensorflow/python/distribute/cross_device_utils_test.py
@@ -103,8 +103,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
         constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
     t1 = math_ops._as_indexed_slices(
         constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
-    device_map = value_lib.ReplicaDeviceMap(("/gpu:0", "/cpu:0"))
-    per_replica = value_lib.PerReplica(device_map, (t0, t1))
+    per_replica = value_lib.PerReplica((t0, t1))
     self.assertTrue(cross_device_utils.contains_indexed_slices(per_replica))
 
   @combinations.generate(combinations.combine(
diff --git a/tensorflow/python/distribute/distribute_lib_test.py b/tensorflow/python/distribute/distribute_lib_test.py
index fb8116d4ab2..fe97d37f417 100644
--- a/tensorflow/python/distribute/distribute_lib_test.py
+++ b/tensorflow/python/distribute/distribute_lib_test.py
@@ -27,7 +27,6 @@ from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
 from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import reduce_util
-from tensorflow.python.distribute import values
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -67,10 +66,8 @@ class _TestExtended(distribute_lib.StrategyExtendedV1):
 
   def __init__(self, distribute):
     super(_TestExtended, self).__init__(distribute)
-    device_map = values.ReplicaDeviceMap(["/device:CPU:0"])
     worker_device_pairs = [("", ["/device:CPU:0"])]
-    self._input_workers = input_lib.InputWorkers(device_map,
-                                                 worker_device_pairs)
+    self._input_workers = input_lib.InputWorkers(worker_device_pairs)
 
   def _call_for_each_replica(self, fn, args, kwargs):
     with _TestReplicaContext(
diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py
index f1f9a0e872d..2b92b7f7c22 100644
--- a/tensorflow/python/distribute/input_lib.py
+++ b/tensorflow/python/distribute/input_lib.py
@@ -130,40 +130,16 @@ def get_distributed_datasets_from_function(dataset_fn,
 class InputWorkers(object):
   """A 1-to-many mapping from input worker devices to compute devices."""
 
-  def __init__(self, device_map, worker_device_pairs=None, logical_device=0):
+  def __init__(self, worker_device_pairs):
     """Initialize an `InputWorkers` object.
 
     Args:
-      device_map: A `DeviceMap` with the computation devices fed by the
-        input workers.
       worker_device_pairs: A sequence of pairs:
         `(input device, a tuple of compute devices fed by that input device)`.
-      logical_device: The logical device of `device_map` to feed.
     """
-    self._device_map = device_map
-    self._logical_device = logical_device
-    if worker_device_pairs is None:
-      devices = device_map.logical_to_actual_devices(logical_device)
-      worker_device_pairs = ((
-          device_util.canonicalize("/device:CPU:0", devices[0]),
-          devices),)
     self._input_worker_devices = tuple(d for d, _ in worker_device_pairs)
     self._fed_devices = tuple(tuple(device_util.canonicalize(d) for d in f)
                               for _, f in worker_device_pairs)
-    flattened = tuple(d for l in self._fed_devices for d in l)
-    assert (flattened ==
-            device_map.logical_to_actual_devices(logical_device)), (
-                "flattened: %s logical device %d: %s" %
-                (flattened, logical_device,
-                 device_map.logical_to_actual_devices(logical_device)))
-
-  @property
-  def device_map(self):
-    return self._device_map
-
-  @property
-  def logical_device(self):
-    return self._logical_device
 
   @property
   def num_workers(self):
@@ -181,8 +157,7 @@ class InputWorkers(object):
     debug_repr = ",\n".join("  %d %s: %s" %
                             (i, devices[i], self._fed_devices[i])
                             for i in range(len(devices)))
-    return "%s:{\n%s\n  device_map: %s}" % (
-        self.__class__.__name__, debug_repr, self._device_map)
+    return "%s:{\n%s}" % (self.__class__.__name__, debug_repr)
 
 
 def _get_next_as_optional(iterator, strategy, name=None):
@@ -213,12 +188,9 @@ def _get_next_as_optional(iterator, strategy, name=None):
   # TODO(b/131423105): we should be able to short-cut the all-reduce in some
   # cases.
   if getattr(strategy.extended, "_support_per_replica_values", True):
-    worker_has_values = values.PerReplica(
-        values.WorkerDeviceMap(
-            worker_devices,
-            num_replicas_per_worker=len(
-                strategy.extended._input_workers._input_worker_devices)),  # pylint: disable=protected-access
-        worker_has_values)
+    # Slight hack: `reduce` expects a `PerReplica`, so we pass it one, even
+    # though it doesn't actually have a value per replica.
+    worker_has_values = values.PerReplica(worker_has_values)
     global_has_value = strategy.reduce(
         reduce_util.ReduceOp.SUM, worker_has_values, axis=None)
   else:
@@ -292,7 +264,7 @@ class DistributedIterator(object):
           # Make `replicas` a flat list of values across all replicas.
           replicas.extend(
               self._iterators[i].get_next_as_list_static_shapes(new_name))
-      return values.regroup(self._input_workers.device_map, replicas)
+      return values.regroup(replicas)
 
     out_of_range_replicas = []
     def out_of_range_fn(worker_index, device):
@@ -352,7 +324,7 @@ class DistributedIterator(object):
               dense_shape=dense_shape)
     replicas = nest.pack_sequence_as(replicas, flattened_replicas)
 
-    return values.regroup(self._input_workers.device_map, replicas)
+    return values.regroup(replicas)
 
   # We need a private initializer method for re-initializing multidevice
   # iterators when used with Keras training loops. If we don't reinitialize the
@@ -459,8 +431,7 @@ class _IterableInput(object):
       else:
         raise ValueError("Dataset iteration within a tf.function is"
                          " not supported for multiple workers.")
-      per_replica_data = values.regroup(self._input_workers.device_map, data)
-      state = reduce_fn(state, per_replica_data)
+      state = reduce_fn(state, values.regroup(data))
       has_data, data = _get_next_as_optional(iterator, self._strategy)
       return has_data, data, state
 
diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py
index 96363053219..433d18d36cb 100644
--- a/tensorflow/python/distribute/input_lib_test.py
+++ b/tensorflow/python/distribute/input_lib_test.py
@@ -138,8 +138,7 @@ class DistributedIteratorTestBase(test.TestCase):
       self.skipTest("unsupported test combination.")
 
     devices = nest.flatten([ds for _, ds in worker_device_pairs])
-    device_map = values.ReplicaDeviceMap(devices)
-    input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)
+    input_workers = input_lib.InputWorkers(worker_device_pairs)
 
     if api_type == "wrap_into_iterator":
       iterator = self._wrap_iterator(
@@ -236,9 +235,7 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
     worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
     dataset_fn = lambda _: dataset_ops.DatasetV1.range(10)
 
-    devices = nest.flatten([ds for _, ds in worker_device_pairs])
-    device_map = values.ReplicaDeviceMap(devices)
-    input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)
+    input_workers = input_lib.InputWorkers(worker_device_pairs)
 
     dist_dataset = input_lib.get_distributed_dataset(
         dataset_fn(distribute_lib.InputContext()), input_workers, distribution)
@@ -260,9 +257,7 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
           ]))
   def testDatasetV2IterError(self, distribution):
     worker_device_pairs = [("", ["/device:CPU:0"])]
-    devices = nest.flatten([ds for _, ds in worker_device_pairs])
-    device_map = values.ReplicaDeviceMap(devices)
-    input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)
+    input_workers = input_lib.InputWorkers(worker_device_pairs)
     dataset_fn = lambda _: dataset_ops.DatasetV2.range(10).batch(2)
 
     dist_dataset = input_lib.get_distributed_dataset(
diff --git a/tensorflow/python/distribute/mirrored_function_strategy.py b/tensorflow/python/distribute/mirrored_function_strategy.py
index aa81aaabfe0..aa9ecfa1fc4 100644
--- a/tensorflow/python/distribute/mirrored_function_strategy.py
+++ b/tensorflow/python/distribute/mirrored_function_strategy.py
@@ -91,8 +91,7 @@ class MirroredFunctionExtended(distribute_lib.StrategyExtendedV1):
     device_tuple = tuple(device_util.resolve(d) for d in devices)
     assert len(set(device_tuple)) == len(device_tuple), (
         "No duplicates allowed in `devices` argument: %s" % (devices,))
-    self._device_map = values.ReplicaDeviceMap(device_tuple)
-
+    self._devices = device_tuple
     self._retrace_functions_for_each_device = False
 
   def _call_for_each_replica(self, fn, args, kwargs):
@@ -116,8 +115,7 @@ class MirroredFunctionExtended(distribute_lib.StrategyExtendedV1):
 
     try:
       with MirroredFunctionReplicaContext(self._container_strategy()):
-        for index, device in enumerate(
-            self._device_map.logical_to_actual_devices(0)):
+        for index, device in enumerate(self._devices):
           _replica_index.current = index
           with ops.device(device):
             if context.executing_eagerly():
@@ -134,7 +132,7 @@ class MirroredFunctionExtended(distribute_lib.StrategyExtendedV1):
       _replica_index.graph_outside_run = None
       _replica_index.current = None
 
-    return values.regroup(self._device_map, return_values)
+    return values.regroup(return_values)
 
   def _local_results(self, val):
     if isinstance(val, values.DistributedValues):
diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py
index 0fb8ae0aafb..b45c52e9ad6 100644
--- a/tensorflow/python/distribute/mirrored_strategy.py
+++ b/tensorflow/python/distribute/mirrored_strategy.py
@@ -89,12 +89,12 @@ class _RequestedStop(Exception):  # pylint: disable=g-bad-exception-name
 
 # TODO(yuefengz): maybe create a common class for those who need to call this
 # _call_for_each_replica.
-def _call_for_each_replica(distribution, device_map, fn, args, kwargs):
+def _call_for_each_replica(distribution, devices, fn, args, kwargs):
   """Run `fn` in separate threads, once per replica/worker device.
 
   Args:
     distribution: the DistributionStrategy object.
-    device_map: the DeviceMap with the devices to run `fn` on.
+    devices: the devices to run `fn` on (logical device 0 for each replica).
     fn: function to run (will be run once per replica, each in its own thread).
     args: positional arguments for `fn`
     kwargs: keyword arguments for `fn`.
@@ -119,11 +119,11 @@ def _call_for_each_replica(distribution, device_map, fn, args, kwargs):
 
   # TODO(isaprykin): Create these threads once instead of during every call.
   threads = []
-  for index in range(device_map.num_replicas_in_graph):
+  for index in range(len(devices)):
     variable_creator_fn = shared_variable_creator.make_fn(
         shared_variable_store, index)
     t = _MirroredReplicaThread(
-        distribution, coord, index, device_map, variable_creator_fn, fn,
+        distribution, coord, index, devices, variable_creator_fn, fn,
         values.select_replica(index, args),
         values.select_replica(index, kwargs))
     threads.append(t)
@@ -173,10 +173,8 @@ def _call_for_each_replica(distribution, device_map, fn, args, kwargs):
             raise RuntimeError("Some replicas made a different number of "
                                "replica_context().merge_call() calls.")
           # get_replica_context().merge_call() case
-          merge_args = values.regroup(
-              device_map, tuple(t.merge_args for t in threads))
-          merge_kwargs = values.regroup(
-              device_map, tuple(t.merge_kwargs for t in threads))
+          merge_args = values.regroup(tuple(t.merge_args for t in threads))
+          merge_kwargs = values.regroup(tuple(t.merge_kwargs for t in threads))
           # We capture the name_scope of the MRT when we call merge_fn
           # to ensure that if we have opened a name scope in the MRT,
           # it will be respected when executing the merge function. We only
@@ -200,7 +198,7 @@ def _call_for_each_replica(distribution, device_map, fn, args, kwargs):
       t.should_run.set()
     coord.join(threads)
 
-  return values.regroup(device_map, tuple(t.main_result for t in threads))
+  return values.regroup(tuple(t.main_result for t in threads))
 
 
 def _is_device_list_single_worker(devices):
@@ -425,8 +423,9 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
 
   def _initialize_single_worker(self, devices):
     """Initializes the object for single-worker training."""
-    self._device_map = values.ReplicaDeviceMap(devices)
-    self._input_workers = input_lib.InputWorkers(self._device_map)
+    self._devices = tuple(device_util.canonicalize(d) for d in devices)
+    self._input_workers = input_lib.InputWorkers(
+        ((device_util.canonicalize("/device:CPU:0", devices[0]), devices),))
     self._inferred_cross_device_ops = None if self._cross_device_ops else (
         cross_device_ops_lib.choose_the_best(devices))
     self._host_input_device = numpy_dataset.SingleDevice(
@@ -461,9 +460,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
     self._default_device = workers[0]
     self._host_input_device = numpy_dataset.SingleDevice(workers[0])
 
-    self._device_map = values.ReplicaDeviceMap(devices)
-    self._input_workers = input_lib.InputWorkers(
-        self._device_map, worker_devices)
+    self._devices = tuple(devices)
+    self._input_workers = input_lib.InputWorkers(worker_devices)
     self._is_multi_worker_training = True
 
     if len(workers) > 1:
@@ -508,16 +506,14 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
     """Create a mirrored variable. See `DistributionStrategy.scope`."""
     colocate_with = kwargs.pop("colocate_with", None)
     if colocate_with is None:
-      device_map = self._device_map
-      logical_device = 0  # TODO(josh11b): Get logical device from scope here.
+      devices = self._devices
     elif isinstance(colocate_with, numpy_dataset.SingleDevice):
       with ops.device(colocate_with.device):
         return next_creator(*args, **kwargs)
     else:
-      device_map = colocate_with.device_map
-      logical_device = colocate_with.logical_device
+      devices = colocate_with.devices
 
-    def _real_mirrored_creator(devices, *args, **kwargs):  # pylint: disable=g-missing-docstring
+    def _real_mirrored_creator(*args, **kwargs):  # pylint: disable=g-missing-docstring
       value_list = []
       for i, d in enumerate(devices):
         with ops.device(d):
@@ -543,9 +539,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
       return value_list
 
     return values.create_mirrored_variable(
-        self._container_strategy(), device_map, logical_device,
-        _real_mirrored_creator, values.MirroredVariable,
-        values.SyncOnReadVariable, *args, **kwargs)
+        self._container_strategy(), _real_mirrored_creator,
+        values.MirroredVariable, values.SyncOnReadVariable, *args, **kwargs)
 
   def _validate_colocate_with_variable(self, colocate_with_variable):
     values.validate_colocate_distributed_variable(colocate_with_variable, self)
@@ -646,8 +641,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
       # For outputs that have already been reduced, wrap them in a Mirrored
       # container, else in a PerReplica container.
       if reduce_op is None:
-        last_step_tensor_outputs_dict[name] = values.regroup(self._device_map,
-                                                             output)
+        last_step_tensor_outputs_dict[name] = values.regroup(output)
       else:
         assert len(output) == 1
         last_step_tensor_outputs_dict[name] = output[0]
@@ -666,8 +660,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
     # TODO(josh11b): In eager mode, use one thread per device, or async mode.
     if not destinations:
       # TODO(josh11b): Use current logical device instead of 0 here.
-      destinations = values.LogicalDeviceSpec(
-          device_map=self._device_map, logical_device=0)
+      destinations = self._devices
     return self._get_cross_device_ops().broadcast(tensor, destinations)
 
   def _call_for_each_replica(self, fn, args, kwargs):
@@ -690,7 +683,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
                           "`experimental_run_v2` inside a tf.function to get "
                           "the best performance." %
                           self._container_strategy().__class__.__name__, 5)
-    return _call_for_each_replica(self._container_strategy(), self._device_map,
+    return _call_for_each_replica(self._container_strategy(), self._devices,
                                   fn, args, kwargs)
 
   def _configure(self,
@@ -706,8 +699,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
     if cluster_spec:
       # TODO(yuefengz): remove the following code once cluster_resolver is
       # added.
-      num_gpus_per_worker = _infer_num_gpus_per_worker(
-          self._device_map.all_devices)
+      num_gpus_per_worker = _infer_num_gpus_per_worker(self._devices)
       multi_worker_devices = _cluster_spec_to_device_list(
           cluster_spec, num_gpus_per_worker)
       self._initialize_multi_worker(multi_worker_devices)
@@ -731,7 +723,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
       # replicas in which case `value` would be a single value or value could
       # be 0.
       return cross_device_ops_lib.reduce_non_distributed_value(
-          reduce_op, self._device_map, value, destinations)
+          reduce_op, value, destinations, self._num_replicas_in_sync)
     return self._get_cross_device_ops().reduce(
         reduce_op, value, destinations=destinations)
 
@@ -743,14 +735,16 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
     # TODO(josh11b): In eager mode, use one thread per device.
     assert isinstance(var, values.DistributedVariable)
     updates = []
-    for i, (d, v) in enumerate(zip(var.devices, var.values)):
+    for i, v in enumerate(var.values):
       name = "update_%d" % i
-      with ops.device(d), distribute_lib.UpdateContext(i), ops.name_scope(name):
+      with ops.device(v.device), \
+           distribute_lib.UpdateContext(i), \
+           ops.name_scope(name):
         # If args and kwargs are not mirrored, the value is returned as is.
         updates.append(fn(v,
-                          *values.select_device_mirrored(d, args),
-                          **values.select_device_mirrored(d, kwargs)))
-    return values.update_regroup(self, self._device_map, updates, group)
+                          *values.select_replica_mirrored(i, args),
+                          **values.select_replica_mirrored(i, kwargs)))
+    return values.update_regroup(self, updates, group)
 
   def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
     assert isinstance(colocate_with, tuple)
@@ -759,9 +753,9 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
     for i, d in enumerate(colocate_with):
       name = "update_%d" % i
       with ops.device(d), distribute_lib.UpdateContext(i), ops.name_scope(name):
-        updates.append(fn(*values.select_device_mirrored(d, args),
-                          **values.select_device_mirrored(d, kwargs)))
-    return values.update_regroup(self, self._device_map, updates, group)
+        updates.append(fn(*values.select_replica_mirrored(i, args),
+                          **values.select_replica_mirrored(i, kwargs)))
+    return values.update_regroup(self, updates, group)
 
   def read_var(self, replica_local_var):
     """Read the aggregate value of a replica-local variable."""
@@ -780,19 +774,19 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
 
   @property
   def _num_replicas_in_sync(self):
-    return self._device_map.num_replicas_in_graph
+    return len(self._devices)
 
   @property
   def worker_devices(self):
-    return self._device_map.all_devices
+    return self._devices
 
   @property
   def worker_devices_by_replica(self):
-    return self._device_map.devices_by_replica
+    return [[d] for d in self._devices]
 
   @property
   def parameter_devices(self):
-    return self._device_map.all_devices
+    return self.worker_devices
 
   @property
   def experimental_between_graph(self):
@@ -813,7 +807,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
   def non_slot_devices(self, var_list):
     del var_list
     # TODO(josh11b): Should this be the last logical device instead?
-    return self._device_map.logical_to_actual_devices(0)
+    return self._devices
 
   # TODO(priyag): Delete this once all strategies use global batch size.
   @property
@@ -835,12 +829,12 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
 class _MirroredReplicaThread(threading.Thread):
   """A thread that runs() a function on a device."""
 
-  def __init__(self, dist, coord, replica_id, device_map, variable_creator_fn,
+  def __init__(self, dist, coord, replica_id, devices, variable_creator_fn,
                fn, args, kwargs):
     super(_MirroredReplicaThread, self).__init__()
     self.coord = coord
     self.distribution = dist
-    self.device_map = device_map
+    self.devices = devices
     self.replica_id = replica_id
     self.variable_creator_fn = variable_creator_fn
     # State needed to run and return the results of `fn`.
@@ -908,8 +902,7 @@ class _MirroredReplicaThread(threading.Thread):
           context.device_policy(self.context_device_policy), \
           MirroredReplicaContext(self.distribution, constant_op.constant(
               self.replica_id, dtypes.int32)), \
-          ops.device(self.device_map.logical_to_actual_devices(0)[
-              self.replica_id]), \
+          ops.device(self.devices[self.replica_id]), \
           ops.name_scope(self._name_scope), \
           variable_scope.variable_scope(
               self._var_scope, reuse=self.replica_id > 0), \
diff --git a/tensorflow/python/distribute/mirrored_strategy_test.py b/tensorflow/python/distribute/mirrored_strategy_test.py
index 32966c904d8..d2bc7ae7285 100644
--- a/tensorflow/python/distribute/mirrored_strategy_test.py
+++ b/tensorflow/python/distribute/mirrored_strategy_test.py
@@ -708,13 +708,21 @@ class MirroredVariableUpdateTest(test.TestCase):
       mirrored_var_result = self.evaluate(
           mirrored_var.assign_add(6.0, read_value=True))
       self.assertEqual(7.0, mirrored_var_result)
-      self.assertEqual(7.0, self.evaluate(mirrored_var.get("/device:CPU:0")))
-      self.assertEqual(7.0, self.evaluate(mirrored_var.get("/device:GPU:0")))
+      self.assertEqual(7.0, self.evaluate(mirrored_var.values[0]))
+      self.assertEqual(7.0, self.evaluate(mirrored_var.values[1]))
+      self.assertEqual(
+          distribution.extended.worker_devices[0], mirrored_var.devices[0])
+      self.assertEqual(
+          distribution.extended.worker_devices[1], mirrored_var.devices[1])
 
       # read_value == False
       self.evaluate(mirrored_var.assign_add(2.0, read_value=False))
-      self.assertEqual(9.0, self.evaluate(mirrored_var.get("/device:CPU:0")))
-      self.assertEqual(9.0, self.evaluate(mirrored_var.get("/device:GPU:0")))
+      self.assertEqual(9.0, self.evaluate(mirrored_var.values[0]))
+      self.assertEqual(9.0, self.evaluate(mirrored_var.values[1]))
+      self.assertEqual(
+          distribution.extended.worker_devices[0], mirrored_var.devices[0])
+      self.assertEqual(
+          distribution.extended.worker_devices[1], mirrored_var.devices[1])
 
   def testAssignAddMirroredVarReplicaContext(self, distribution):
     def var_fn():
@@ -766,8 +774,12 @@ class MirroredVariableUpdateTest(test.TestCase):
       self.assertEqual(5.0, self.evaluate(mirrored_var))
       mirrored_var_result = self.evaluate(mirrored_var.assign_sub(2.0))
       self.assertEqual(3.0, mirrored_var_result)
-      self.assertEqual(3.0, self.evaluate(mirrored_var.get("/device:GPU:0")))
-      self.assertEqual(3.0, self.evaluate(mirrored_var.get("/device:CPU:0")))
+      self.assertEqual(3.0, self.evaluate(mirrored_var.values[0]))
+      self.assertEqual(3.0, self.evaluate(mirrored_var.values[1]))
+      self.assertEqual(
+          distribution.extended.worker_devices[0], mirrored_var.devices[0])
+      self.assertEqual(
+          distribution.extended.worker_devices[1], mirrored_var.devices[1])
 
   def testAssignSubMirroredVarReplicaContext(self, distribution):
     def var_fn():
@@ -978,8 +990,8 @@ class MirroredStrategyDefunTest(test.TestCase):
         per_replica_graph_functions = (
             distribution.extended.call_for_each_replica(
                 defun.get_concrete_function, args=[mock_model] + inputs))
-        for device in devices:
-          graph_function = per_replica_graph_functions.get(device=device)
+        for i in range(len(devices)):
+          graph_function = per_replica_graph_functions.values[i]
           # TODO(b/129555712): re-enable an assertion here that the two sets of
           # variables are the same.
           # self.assertEqual(set(graph_function.graph.variables),
@@ -1050,9 +1062,8 @@ class MirroredStrategyDefunTest(test.TestCase):
     def fn1(mock_model, factor):
       return mock_model(factor)
 
-    device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0"))
-    factors = values.PerReplica(device_map, (5.0, 3.0))
-    expected_result = values.PerReplica(device_map, (5.0 * 1.25, 3.0 * 1.25))
+    factors = values.PerReplica((5.0, 3.0))
+    expected_result = values.PerReplica((5.0 * 1.25, 3.0 * 1.25))
     self._call_and_check(distribution, fn1, [factors], expected_result, [fn1])
 
   def testTrain(self, distribution):
diff --git a/tensorflow/python/distribute/mirrored_variable_test.py b/tensorflow/python/distribute/mirrored_variable_test.py
index f237ee19205..3cc75451827 100644
--- a/tensorflow/python/distribute/mirrored_variable_test.py
+++ b/tensorflow/python/distribute/mirrored_variable_test.py
@@ -87,9 +87,9 @@ class MirroredVariableCreationTest(test.TestCase):
     self.assertIsInstance(var, values.MirroredVariable)
     self.assertEqual(name, var.name)
     self.assertIs(strategy, var.distribute_strategy)
-    for d in var.devices:
-      self.assertEqual(d, var.get(d).device)
-      self.assertIs(strategy, var.get(d)._distribute_strategy)  # pylint: disable=protected-access
+    for i, d in enumerate(var.devices):
+      self.assertEqual(d, var.values[i].device)
+      self.assertIs(strategy, var.values[i]._distribute_strategy)  # pylint: disable=protected-access
 
   def testVariableInFuncGraph(self, distribution):
 
@@ -323,16 +323,15 @@ class MirroredVariableCreationTest(test.TestCase):
           aggregation=aggregation)
       return v0, v1
 
-    devices = distribution.extended.worker_devices
     with distribution.scope():
       v0, v1 = distribution.extended.call_for_each_replica(create_fn)
       self.evaluate(v0.initializer)
-      self.assertEqual(2.0, self.evaluate(v0.get(devices[0])))
-      self.assertEqual(2.0, self.evaluate(v0.get(devices[1])))
+      self.assertEqual(2.0, self.evaluate(v0.values[0]))
+      self.assertEqual(2.0, self.evaluate(v0.values[1]))
       self.assertEqual(2.0, self.evaluate(distribution.extended.read_var(v0)))
       self.evaluate(v1.initializer)
-      self.assertEqual(3.0, self.evaluate(v1.get(devices[0])))
-      self.assertEqual(3.0, self.evaluate(v1.get(devices[1])))
+      self.assertEqual(3.0, self.evaluate(v1.values[0]))
+      self.assertEqual(3.0, self.evaluate(v1.values[1]))
       self.assertEqual(3.0, self.evaluate(distribution.extended.read_var(v1)))
 
       def replica_id_plus_one():
@@ -349,20 +348,20 @@ class MirroredVariableCreationTest(test.TestCase):
 
       # Update "sync on read" variable.
       self.evaluate(distribution.group(update0a))
-      self.assertEqual(2.0 + 5.0, self.evaluate(v0.get(devices[0])))
+      self.assertEqual(2.0 + 5.0, self.evaluate(v0.values[0]))
       # Writes are not synchronized for "sync on read" variables,
       # so device[1] can end up with a different value.
-      self.assertEqual(2.0 + 2 * 5.0, self.evaluate(v0.get(devices[1])))
+      self.assertEqual(2.0 + 2 * 5.0, self.evaluate(v0.values[1]))
       # Always reads from device 0.
       self.assertEqual(2.0 + 5.0,
                        self.evaluate(distribution.extended.read_var(v0)))
 
       # Update "sync on write" variable.
       self.evaluate(distribution.group(update1a))
-      self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[0])))
+      self.assertEqual(3.0 + 7.0, self.evaluate(v1.values[0]))
       # Writes are synchronized for v1, only the argument to assign_add on
       # device[0] is used.
-      self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[1])))
+      self.assertEqual(3.0 + 7.0, self.evaluate(v1.values[1]))
       self.assertEqual(3.0 + 7.0,
                        self.evaluate(distribution.extended.read_var(v1)))
 
@@ -377,16 +376,15 @@ class MirroredVariableCreationTest(test.TestCase):
       self.evaluate(distribution.group(update0b))
 
       # Update "sync on read" variable.
-      self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(v0.get(devices[0])))
-      self.assertEqual(2.0 + 2 * 5.0 + 2 * 11.0,
-                       self.evaluate(v0.get(devices[1])))
+      self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(v0.values[0]))
+      self.assertEqual(2.0 + 2 * 5.0 + 2 * 11.0, self.evaluate(v0.values[1]))
       self.assertEqual(2.0 + 5.0 + 11.0,
                        self.evaluate(distribution.extended.read_var(v0)))
 
       # Update "sync on write" variable.
       self.evaluate(distribution.group(update1b))
-      self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[0])))
-      self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[1])))
+      self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.values[0]))
+      self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.values[1]))
       self.assertEqual(3.0 + 7.0 + 13.0,
                        self.evaluate(distribution.extended.read_var(v1)))
 
@@ -448,8 +446,7 @@ class MirroredVariableCreationTest(test.TestCase):
       return v
 
     with distribution.scope():
-      device_map = values.ReplicaDeviceMap(distribution.extended.worker_devices)
-      names = values.DistributedValues(device_map, ("foo", "bar"))
+      names = values.DistributedValues(("foo", "bar"))
       with self.assertRaises(RuntimeError):
         _ = distribution.extended.call_for_each_replica(model_fn, args=(names,))
 
@@ -512,10 +509,10 @@ class MirroredVariableCreationTest(test.TestCase):
       ])
       expected_sum = 0.0
       expected_mean = 0.0
-      for i, d in enumerate(distribution.extended.worker_devices):
+      for i, _ in enumerate(distribution.extended.worker_devices):
         # Should see different values on different devices.
-        v_sum_value = self.evaluate(ret_v_sum.get(d).read_value())
-        v_mean_value = self.evaluate(ret_v_mean.get(d).read_value())
+        v_sum_value = self.evaluate(ret_v_sum.values[i].read_value())
+        v_mean_value = self.evaluate(ret_v_mean.values[i].read_value())
         expected = i + 3.0
         self.assertEqual(expected, v_sum_value)
         expected_sum += expected
@@ -578,11 +575,7 @@ class MirroredVariableCreationTest(test.TestCase):
       self.evaluate(variables.global_variables_initializer())
       # Assert that the aggregated value of the sync on read var is the sum
       # of the individual values before running the update ops.
-      self.assertEqual(
-          1.0,
-          self.evaluate(
-              ret_v_sum.get(
-                  distribution.extended.worker_devices[0]).read_value()))
+      self.assertEqual(1.0, self.evaluate(ret_v_sum.values[0].read_value()))
       self.assertEqual(2.0, self.evaluate(ret_v_sum))
 
       # Apply updates.
@@ -591,11 +584,7 @@ class MirroredVariableCreationTest(test.TestCase):
       self.evaluate(update_ops)
       # Assert that the aggregated value of the sync on read vars is the sum
       # of the individual values after running the update ops.
-      self.assertEqual(
-          5.0,
-          self.evaluate(
-              ret_v_sum.get(
-                  distribution.extended.worker_devices[0]).read_value()))
+      self.assertEqual(5.0, self.evaluate(ret_v_sum.values[0].read_value()))
       self.assertEqual(10.0, self.evaluate(ret_v_sum))
 
   def testVarDistributeStrategy(self, distribution):
diff --git a/tensorflow/python/distribute/one_device_strategy.py b/tensorflow/python/distribute/one_device_strategy.py
index 7963a23c20f..144ce6a8fce 100644
--- a/tensorflow/python/distribute/one_device_strategy.py
+++ b/tensorflow/python/distribute/one_device_strategy.py
@@ -251,9 +251,7 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
     suffix_loc = self._device.rfind("/")
     self._input_device = self._device[:suffix_loc] + "/device:CPU:0"
     worker_device_pairs = [(self._input_device, [self._device])]
-    device_map = values.SingleDeviceMap(self._device)
-    self._input_workers = input_lib.InputWorkers(
-        device_map, worker_device_pairs)
+    self._input_workers = input_lib.InputWorkers(worker_device_pairs)
 
   def _create_variable(self, next_creator, *args, **kwargs):
     colocate_with = kwargs.pop("colocate_with", None)
diff --git a/tensorflow/python/distribute/parameter_server_strategy.py b/tensorflow/python/distribute/parameter_server_strategy.py
index 1815fc2a669..d5305ed910a 100644
--- a/tensorflow/python/distribute/parameter_server_strategy.py
+++ b/tensorflow/python/distribute/parameter_server_strategy.py
@@ -213,9 +213,10 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
     else:
       compute_devices = (worker_device,)
 
-    self._device_map = values.ReplicaDeviceMap(compute_devices)
+    self._compute_devices = [
+        device_util.canonicalize(d) for d in compute_devices]
     self._input_workers = input_lib.InputWorkers(
-        self._device_map, [(worker_device, compute_devices)])
+        [(worker_device, compute_devices)])
 
     # In distributed mode, place variables on ps jobs in a round-robin fashion.
     # Note that devices returned from `replica_device_setter` are not
@@ -253,9 +254,9 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
     logging.info(
         "Multi-worker ParameterServerStrategy with "
         "cluster_spec = %r, task_type = %r, task_id = %r, "
-        "num_ps_replicas = %r, is_chief = %r, device_map = %r, "
+        "num_ps_replicas = %r, is_chief = %r, compute_devices = %r, "
         "variable_device = %r", cluster_spec.as_dict(), task_type, task_id,
-        num_ps_replicas, self._is_chief, self._device_map,
+        num_ps_replicas, self._is_chief, self._compute_devices,
         self._variable_device)
 
   # TODO(yuefengz): get rid of cluster_resolver argument when contrib's
@@ -279,6 +280,8 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
 
       compute_devices = device_util.local_devices_from_num_gpus(num_gpus)
 
+    compute_devices = [device_util.canonicalize(d) for d in compute_devices]
+
     if parameter_device is None:
       # If there is only one GPU, put everything on that GPU. Otherwise, place
       # variables on CPU.
@@ -287,11 +290,11 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
       else:
         parameter_device = _LOCAL_CPU
 
-    self._device_map = values.ReplicaDeviceMap(compute_devices)
     self._input_workers = input_lib.InputWorkers(
-        self._device_map, [(worker_device, compute_devices)])
+        [(worker_device, compute_devices)])
 
     self._variable_device = parameter_device
+    self._compute_devices = compute_devices
     self._parameter_devices = (parameter_device,)
     self._is_chief = True
     self._cluster_spec = None
@@ -376,8 +379,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
       return tensor
     if not cross_device_ops_lib.check_destinations(destinations):
       # TODO(josh11b): Use current logical device instead of 0 here.
-      destinations = values.LogicalDeviceSpec(
-          device_map=self._device_map, logical_device=0)
+      destinations = self._compute_devices
     return self._cross_device_ops.broadcast(tensor, destinations)
 
   def _allow_variable_partition(self):
@@ -449,7 +451,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
   def _call_for_each_replica(self, fn, args, kwargs):
     # pylint: disable=protected-access
     return mirrored_strategy._call_for_each_replica(
-        self._container_strategy(), self._device_map, fn, args, kwargs)
+        self._container_strategy(), self._compute_devices, fn, args, kwargs)
 
   def _verify_destinations_not_different_worker(self, destinations):
     if not self._cluster_spec:
@@ -468,7 +470,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
     if not isinstance(value, values.DistributedValues):
       # pylint: disable=protected-access
       return cross_device_ops_lib.reduce_non_distributed_value(
-          reduce_op, self._device_map, value, destinations)
+          reduce_op, value, destinations, self._num_replicas_in_sync)
     return self._cross_device_ops.reduce(
         reduce_op, value, destinations=destinations)
 
@@ -605,15 +607,15 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
 
   @property
   def _num_replicas_in_sync(self):
-    return self._device_map.num_replicas_in_graph
+    return len(self._compute_devices)
 
   @property
   def worker_devices(self):
-    return self._device_map.all_devices
+    return self._compute_devices
 
   @property
   def worker_devices_by_replica(self):
-    return self._device_map.devices_by_replica
+    return [[d] for d in self._compute_devices]
 
   @property
   def parameter_devices(self):
diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index 2dd4309537a..3987bf390ff 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -201,8 +201,6 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
 
     self._host_device = device_util.get_host_for_device(self._tpu_devices[0])
 
-    self._device_map = values.ReplicaDeviceMap(self._tpu_devices)
-
     # Preload the data onto the TPUs.
     input_worker_devices = collections.OrderedDict()
     for tpu_device in self._tpu_devices:
@@ -210,7 +208,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
       input_worker_devices.setdefault(host_device, [])
       input_worker_devices[host_device].append(tpu_device)
     self._input_workers = input_lib.InputWorkers(
-        self._device_map, tuple(input_worker_devices.items()))
+        tuple(input_worker_devices.items()))
 
     # TODO(sourabhbajaj): Remove this once performance of running one step
     # at a time is comparable to multiple steps.
@@ -395,16 +393,14 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
 
     colocate_with = kwargs.pop("colocate_with", None)
     if colocate_with is None:
-      device_map = self._device_map
-      logical_device = 0  # TODO(josh11b): Get logical device from scope here.
+      devices = self._tpu_devices
     elif isinstance(colocate_with, numpy_dataset.SingleDevice):
       with ops.device(colocate_with.device):
         return next_creator(*args, **kwargs)
     else:
-      device_map = colocate_with.device_map
-      logical_device = colocate_with.logical_device
+      devices = colocate_with.devices
 
-    def _real_mirrored_creator(devices, *args, **kwargs):  # pylint: disable=g-missing-docstring
+    def _real_mirrored_creator(*args, **kwargs):  # pylint: disable=g-missing-docstring
       initial_value = None
       value_list = []
       for i, d in enumerate(devices):
@@ -434,9 +430,9 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
       return value_list
 
     return values.create_mirrored_variable(
-        self._container_strategy(), device_map, logical_device,
-        _real_mirrored_creator, values.TPUMirroredVariable,
-        values.TPUSyncOnReadVariable, *args, **kwargs)
+        self._container_strategy(), _real_mirrored_creator,
+        values.TPUMirroredVariable, values.TPUSyncOnReadVariable,
+        *args, **kwargs)
 
   def _reduce_to(self, reduce_op, value, destinations):
     if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
@@ -454,7 +450,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
       # replicas in which case `value` would be a single value or value could
       # be 0.
       return cross_device_ops_lib.reduce_non_distributed_value(
-          reduce_op, self._device_map, value, destinations)
+          reduce_op, value, destinations, self._num_replicas_in_sync)
 
     # TODO(cjfj): Detect when it is possible to use `cross_replica_sum`.
     # Always performs the reduction on the TPU host.
@@ -490,14 +486,16 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
     # Otherwise, we revert to MirroredStrategy behavior and update each variable
     # directly.
     updates = []
-    for i, (d, v) in enumerate(zip(var.devices, var.values)):
+    for i, v in enumerate(var.values):
       name = "update_%d" % i
-      with ops.device(d), distribute_lib.UpdateContext(i), ops.name_scope(name):
+      with ops.device(v.device), \
+           distribute_lib.UpdateContext(i), \
+           ops.name_scope(name):
         # If args and kwargs are not mirrored, the value is returned as is.
         updates.append(fn(v,
-                          *values.select_device_mirrored(d, args),
-                          **values.select_device_mirrored(d, kwargs)))
-    return values.update_regroup(self, self._device_map, updates, group)
+                          *values.select_replica_mirrored(i, args),
+                          **values.select_replica_mirrored(i, kwargs)))
+    return values.update_regroup(self, updates, group)
 
   def read_var(self, var):
     assert isinstance(var, values.TPUVariableMixin) or isinstance(
@@ -706,8 +704,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
             nest.pack_sequence_as(result[0], nest.flatten(replica_output))
             for replica_output in replicate_outputs
         ]
-      device_map = self._device_map  # pylint: disable=protected-access
-      return values.regroup(device_map, replicate_outputs)
+      return values.regroup(replicate_outputs)
 
     if context.executing_eagerly():
       tpu_function = def_function.function(tpu_function)
diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py
index 0c2a9ccdaac..df232545cfa 100644
--- a/tensorflow/python/distribute/values.py
+++ b/tensorflow/python/distribute/values.py
@@ -21,7 +21,6 @@ from __future__ import print_function
 import collections
 import contextlib
 import weakref
-import six
 
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_lib
@@ -44,325 +43,76 @@ from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.util import nest
 
 
-def _devices_match(d1, d2):
-  return device_util.canonicalize(d1) == device_util.canonicalize(d2)
-
-
-class DeviceMap(object):
-  """A mapping of replicas & logical device ids to devices."""
-
-  @property
-  def all_devices(self):
-    """Returns a tuple of strings with all devices in this DeviceMap."""
-    raise NotImplementedError("Required for DeviceMap implementations.")
-
-  @property
-  def devices_by_replica(self):
-    """Returns a tuple `t` where `t[replica]` is the devices for `replica`."""
-    raise NotImplementedError("Required for DeviceMap implementations.")
-
-  @property
-  def num_logical_devices(self):
-    """Count of the number of devices each replica may be defined across."""
-    raise NotImplementedError("Required for DeviceMap implementations.")
-
-  @property
-  def num_replicas_in_graph(self):
-    """Number of replicas defined in this graph."""
-    raise NotImplementedError("Required for DeviceMap implementations.")
-
-  def logical_device_from_values(self, values):
-    """Returns the logical device index `values` is on."""
-    raise NotImplementedError("Required for DeviceMap implementations.")
-
-  def logical_to_actual_devices(self, logical_device_id):
-    """Returns sequence of `num_replicas_in_graph` devices."""
-    raise NotImplementedError("Required for DeviceMap implementations.")
-
-  def select_for_current_replica(self, values, replica_context):
-    """Select the element of `values` for the current replica."""
-    raise NotImplementedError("Required for DeviceMap implementations.")
-
-  def replica_for_device(self, device):
-    """Return the replica id containing `device`."""
-    raise NotImplementedError("Required for DeviceMap implementations.")
-
-  def select_for_device(self, values, device):
-    """Select the element of `values` to access from `device`."""
-    raise NotImplementedError("Required for DeviceMap implementations.")
-
-  def is_device_in_replica(self, device, replica_id):
-    """Returns whether `device` is a member of replica `replica_id`."""
-    raise NotImplementedError("Required for DeviceMap implementations.")
-
-
-class SingleDeviceMap(DeviceMap):
-  """A device map for 1 non-computation device.
-
-  Use `SingleDeviceMap` when the device does not correspond to some replica of
-  the computation. For computation devices, use `ReplicaDeviceMap` below (even
-  if there is only a single device in the map).
-  """
-
-  def __init__(self, device):
-    """Initialize a `SingleDeviceMap`.
-
-    Args:
-      device: A string device.
-    """
-    assert isinstance(device, six.string_types)
-    self._device = device_util.canonicalize(device)
-    self._devices = (self._device,)
-
-  @property
-  def all_devices(self):
-    return self._devices
-
-  @property
-  def devices_by_replica(self):
-    raise ValueError("SingleDeviceMap not indexed by replicas")
-
-  @property
-  def num_logical_devices(self):
-    return 1
-
-  @property
-  def num_replicas_in_graph(self):
-    return 1
-
-  def logical_device_from_values(self, values):
-    del values
-    return 0
-
-  def logical_to_actual_devices(self, logical_device_id):
-    assert logical_device_id == 0
-    return self._devices
-
-  def select_for_current_replica(self, values, replica_context):
-    assert len(values) == 1
-    del replica_context
-    return values[0]
-
-  def replica_for_device(self, device):
-    raise ValueError("SingleDeviceMap not indexed by replicas")
-
-  def select_for_device(self, values, device):
-    assert len(values) == 1
-    if self._device != device:
-      raise ValueError("Device %s not found in %s (current device %s)" %
-                       (device, self._devices, device_util.current()))
-    return values[0]
-
-  def is_device_in_replica(self, device, replica_id):
-    raise ValueError("SingleDeviceMap not indexed by replicas")
-
-  def __repr__(self):
-    return "%s(%r)" % (self.__class__.__name__, self._device)
-
-
-class ReplicaDeviceMap(DeviceMap):
-  """A device map for 1 device per replica."""
-
-  def __init__(self, devices):
-    """Initialize a `ReplicaDeviceMap`.
-
-    Args:
-      devices: `devices[i]` is the string device for replica `i`.
-    """
-    self._devices = tuple(device_util.canonicalize(d) for d in devices)
-    if len(set(self._devices)) != len(self._devices):
-      raise ValueError("Duplicate devices in %s, after canonicalization: %s" %
-                       (devices, self._devices))
-    self._device_to_replica = {d: r for r, d in enumerate(self._devices)}
-
-  @property
-  def all_devices(self):
-    return self._devices
-
-  @property
-  def devices_by_replica(self):
-    return ((d,) for d in self._devices)
-
-  @property
-  def num_logical_devices(self):
-    return 1
-
-  @property
-  def num_replicas_in_graph(self):
-    return len(self._devices)
-
-  def logical_device_from_values(self, values):
-    del values
-    return 0
-
-  def logical_to_actual_devices(self, logical_device_id):
-    assert logical_device_id == 0
-    return self._devices
-
-  def select_for_current_replica(self, values, replica_context):
-    assert len(values) == len(self._devices)
+def _get_current_replica_id_as_int():
+  """Returns the current replica ID as an integer, or `None`."""
+  replica_context = distribution_strategy_context.get_replica_context()
+  if replica_context:
     replica_id = replica_context.replica_id_in_sync_group
     if not isinstance(replica_id, int):
       replica_id = tensor_util.constant_value(replica_id)
-    if replica_id is None:
-      replica_id = 0
-    return values[replica_id]
-
-  def replica_for_device(self, device):
-    return self._device_to_replica.get(device)
-
-  def select_for_device(self, values, device):
-    assert len(values) == len(self._devices)
-    replica_id = self._device_to_replica.get(device)
-    if replica_id is None:
-      raise ValueError("Device %s not found in %s (current device %s)" %
-                       (device, self._devices, device_util.current()))
-    return values[replica_id]
-
-  def is_device_in_replica(self, device, replica_id):
-    return _devices_match(device, self._devices[replica_id])
-
-  def __str__(self):
-    return "[%s]" % (", ".join(self._devices))
-
-  def __repr__(self):
-    return "%s([%s])" % (self.__class__.__name__, ", ".join(
-        repr(d) for d in self._devices))
-
-
-LogicalDeviceSpec = collections.namedtuple("LogicalDeviceSpec",
-                                           ("device_map", "logical_device"))
-
-
-class WorkerDeviceMap(DeviceMap):
-  """A device map for one value per worker."""
-
-  def __init__(self, devices, num_replicas_per_worker):
-    """Initialize a `WorkerDeviceMap`.
-
-    Args:
-      devices: `devices[i]` is the string device for worker `i` in in-graph
-        relication case; devices is single-element list for its corresponding
-        worker in between-graph case.
-      num_replicas_per_worker: number of replicas per worker, useful in in-graph
-        replication case.
-    """
-    self._devices = tuple(device_util.canonicalize(d) for d in devices)
-    if len(set(self._devices)) != len(self._devices):
-      raise ValueError("Duplicate devices in %s, after canonicalization: %s" %
-                       (devices, self._devices))
-    self._num_replicas_per_worker = num_replicas_per_worker
-
-  @property
-  def all_devices(self):
-    return self._devices
-
-  @property
-  def devices_by_replica(self):
-    raise ValueError("`WorkerDeviceMap` is not indexed by replicas")
-
-  @property
-  def num_logical_devices(self):
-    return 1
-
-  @property
-  def num_replicas_in_graph(self):
-    return len(self._devices)
-
-  def logical_device_from_values(self, values):
-    del values
-    return 0
-
-  def logical_to_actual_devices(self, logical_device_id):
-    assert logical_device_id == 0
-    return self._devices
-
-  def select_for_current_replica(self, values, replica_context):
-    return values[replica_context.replica_id_in_sync_group //
-                  self._num_replicas_per_worker]
-
-  def replica_for_device(self, device):
-    raise ValueError("`WorkerDeviceMap` not indexed by replicas")
-
-  def select_for_device(self, values, device):
-    # TODO(yuefengz): this should map from any device to the value on its
-    # corresponding worker.
-    return values[self._devices.index(device_util.canonicalize(device))]
-
-  def is_device_in_replica(self, device, replica_id):
-    raise ValueError("WorkerDeviceMap not indexed by replicas")
-
-  def __repr__(self):
-    return "%s(%r, num_replicas_per_worker=%d)" % (
-        self.__class__.__name__, self._devices, self._num_replicas_per_worker)
+  else:
+    replica_id = distribute_lib.get_update_replica_id()
+  return replica_id
 
 
 class DistributedValues(object):
   """Holds a map from replica to values. Either PerReplica or Mirrored."""
 
-  def __init__(self, device_map, values, logical_device=None):
-    assert isinstance(device_map, DeviceMap)
-    self._device_map = device_map
+  def __init__(self, values):
     self._values = tuple(values)
-    if logical_device is None:
-      logical_device = device_map.logical_device_from_values(self._values)
-    self._logical_device = logical_device
 
-  # TODO(josh11b): Split this into two functions, one with device, one without.
-  def get(self, device=None):
+  def get(self):
     """Returns the value for the current device or raises a ValueError."""
-    if device is None:
-      replica_context = distribution_strategy_context.get_replica_context()
-      if replica_context:
-        return self._device_map.select_for_current_replica(
-            self._values, replica_context)
-      else:
-        update_replica_id = distribute_lib.get_update_replica_id()
-        if update_replica_id is None:
-          return self._get_cross_replica()
-        else:
-          return self._values[update_replica_id]
-    device = device_util.canonicalize(device)
-    return self._device_map.select_for_device(self._values, device)
+    replica_id = _get_current_replica_id_as_int()
+    if replica_id is None:
+      return self._get_cross_replica()
+    else:
+      return self._values[replica_id]
+
+  def _get_cross_replica(self):
+    raise NotImplementedError(
+        "This method should be overridden by sub-classes which support cross-"
+        "replica accesses.")
+
+  def _get_closest(self):
+    """Returns value in same replica or device if possible, else the primary."""
+    replica_id = _get_current_replica_id_as_int()
+    if replica_id is None:
+      # Try to find a value on the current device.
+      current_device = device_util.canonicalize(device_util.current())
+      for value in self._values:
+        if device_util.canonicalize(value.device) == current_device:
+          return value
+      return self.primary
+    else:
+      return self._values[replica_id]
 
   @property
   def primary(self):
     """Returns a representative component."""
     return self._values[0]
 
-  @property
-  def devices(self):
-    return self._device_map.logical_to_actual_devices(self._logical_device)
-
-  @property
-  def logical_device(self):
-    return self._logical_device
-
-  @property
-  def device_map(self):
-    return self._device_map
-
   # TODO(josh11b): Replace experimental_local_results with this?
   @property
   def values(self):
     return self._values
 
+  @property
+  def devices(self):
+    return tuple(v.device for v in self._values)
+
   @property
   def is_tensor_like(self):
     return all(tensor_util.is_tensor(v) for v in self._values)
 
   def __str__(self):
-    devices = self.devices
-    assert len(self._values) == len(devices)
-    debug_str = ",\n".join("  %d %s: %s" % (i, devices[i], self._values[i])
-                           for i in range(len(devices)))
+    debug_str = ",\n".join(
+        "  %d: %s" % (i, v) for i, v in enumerate(self._values))
     return "%s:{\n%s\n}" % (self.__class__.__name__, debug_str)
 
   def __repr__(self):
-    devices = self.devices
-    assert len(self._values) == len(devices)
-    debug_repr = ",\n".join("  %d %s: %r" % (i, devices[i], self._values[i])
-                            for i in range(len(devices)))
+    debug_repr = ",\n".join(
+        "  %d: %r" % (i, v) for i, v in enumerate(self._values))
     return "%s:{\n%s\n}" % (self.__class__.__name__, debug_repr)
 
 
@@ -523,28 +273,22 @@ class PerReplica(DistributedValues, composite_tensor.CompositeTensor):
 
   @property
   def _type_spec(self):
-    value_specs = nest.map_structure(type_spec.type_spec_from_value,
-                                     self._values)
-    return PerReplicaSpec(value_specs, self._device_map, self._logical_device)
+    return PerReplicaSpec(
+        *(type_spec.type_spec_from_value(v) for v in self._values))
 
 
 class PerReplicaSpec(type_spec.TypeSpec):
   """Type specification for a `PerReplica`."""
 
-  __slots__ = ["_value_specs", "_device_map", "_logical_device"]
+  __slots__ = ["_value_specs"]
 
   value_type = property(lambda self: PerReplica)
 
-  def __init__(self, value_specs, device_map, logical_device):
-    if isinstance(device_map, tuple):
-      device_map = self._deserialize_device_map(device_map)
+  def __init__(self, *value_specs):
     self._value_specs = tuple(value_specs)
-    self._device_map = device_map
-    self._logical_device = logical_device
 
   def _serialize(self):
-    device_map = self._serialize_device_map(self._device_map)
-    return (self._value_specs, device_map, self._logical_device)
+    return self._value_specs
 
   @property
   def _component_specs(self):
@@ -559,34 +303,7 @@ class PerReplicaSpec(type_spec.TypeSpec):
     return value._values  # pylint: disable=protected-access
 
   def _from_components(self, tensor_list):
-    return PerReplica(
-        self._device_map, tensor_list, logical_device=self._logical_device)
-
-  @staticmethod
-  def _serialize_device_map(device_map):
-    if isinstance(device_map, SingleDeviceMap):
-      return ("single", device_map.all_devices[0])
-    elif isinstance(device_map, ReplicaDeviceMap):
-      return ("replica", device_map.all_devices)
-    elif isinstance(device_map, WorkerDeviceMap):
-      return ("worker", device_map.all_devices,
-              device_map.num_replicas_per_worker)
-    else:
-      raise ValueError("PerReplicaSpec does not support device_map type %s" %
-                       type(device_map).__name__)
-
-  @staticmethod
-  def _deserialize_device_map(device_map_info):
-    device_map_type = device_map_info[0]
-    device_map_args = device_map_info[1:]
-    if device_map_type == "single":
-      return SingleDeviceMap(*device_map_args)
-    elif device_map_type == "replica":
-      return ReplicaDeviceMap(*device_map_args)
-    elif device_map_type == "worker":
-      return WorkerDeviceMap(*device_map_args)
-    else:
-      raise ValueError("Unexpected value in state tuple")
+    return PerReplica(tensor_list)
 
 
 # Note that unlike PerReplica, Mirrored values inherit from
@@ -596,11 +313,7 @@ class Mirrored(DistributedDelegate):
   """Holds a map from replica to values which are kept in sync."""
 
   def _get_cross_replica(self):
-    device = device_util.canonicalize(device_util.current())
-    replica_id = self._device_map.replica_for_device(device)
-    if replica_id is None:
-      return self.primary
-    return self._values[replica_id]
+    return self._get_closest()
 
   def _as_graph_element(self):
     obj = self.get()
@@ -656,10 +369,9 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
   # TODO(josh11b): Support changing the set of variables if e.g. if new
   # devices are joining or a device is to leave.
 
-  def __init__(self, strategy, device_map, values, logical_device=None):
+  def __init__(self, strategy, values):
     self._distribute_strategy = strategy
-    super(DistributedVariable, self).__init__(
-        device_map, values, logical_device=logical_device)
+    super(DistributedVariable, self).__init__(values)
     self._common_name = self.primary.name.split(":")[0]
     # Use a weakref to make it easy to map from the contained values
     # to the container without introducing a reference cycle.
@@ -709,21 +421,6 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
           tuple(v.initializer for v in self._values))
     return init_op
 
-  def _get_closest(self):
-    """Return member in the same replica if possible, else the primary."""
-    replica_context = distribution_strategy_context.get_replica_context()
-    if replica_context:
-      return self._device_map.select_for_current_replica(
-          self._values, replica_context)
-    update_replica_id = distribute_lib.get_update_replica_id()
-    if update_replica_id is not None:
-      return self._values[update_replica_id]
-    device = device_util.canonicalize(device_util.current())
-    replica_id = self._device_map.replica_for_device(device)
-    if replica_id is None:
-      return self.primary
-    return self._values[replica_id]
-
   def initialized_value(self):
     return self._get_closest().initialized_value()
 
@@ -766,14 +463,12 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
 
   @property
   def handle(self):
-    replica_context = distribution_strategy_context.get_replica_context()
-    if replica_context is None:
-      update_replica_id = distribute_lib.get_update_replica_id()
-      if update_replica_id is None:
-        raise ValueError("`handle` is not available outside the replica context"
-                         " or a `tf.distribute.Strategy.update()` call.")
-      return self._values[update_replica_id].handle
-    return self.get().handle
+    replica_id = _get_current_replica_id_as_int()
+    if replica_id is None:
+      raise ValueError("`handle` is not available outside the replica context"
+                       " or a `tf.distribute.Strategy.update()` call.")
+    else:
+      return self._values[replica_id].handle
 
   def eval(self, session=None):
     return self._get_closest().eval(session)
@@ -883,9 +578,9 @@ class TPUVariableMixin(object):
       raise AttributeError(
           "'{}' not accessible within a TPU context.".format(name))
 
-  def get(self, device=None):
-    if (_enclosing_tpu_context() is None) or (device is not None):
-      return super(TPUVariableMixin, self).get(device=device)
+  def get(self):
+    if _enclosing_tpu_context() is None:
+      return super(TPUVariableMixin, self).get()
     else:
       raise NotImplementedError(
           "`TPUVariableMixin.get()` is not supported within a TPU context.")
@@ -917,10 +612,8 @@ class TPUVariableMixin(object):
     if tpu_context is None:
       return self._get_closest().handle
     else:
-      return tpu_context.get_replicated_var_handle(self._handle_id,
-                                                   self._values,
-                                                   self._device_map,
-                                                   self._is_mirrored())
+      return tpu_context.get_replicated_var_handle(
+          self._handle_id, self._values, self._is_mirrored())
 
   @property
   def device(self):
@@ -1035,8 +728,8 @@ class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
 
 
 def create_mirrored_variable(  # pylint: disable=missing-docstring
-    strategy, device_map, logical_device, real_mirrored_creator, mirrored_cls,
-    sync_on_read_cls, *args, **kwargs):
+    strategy, real_mirrored_creator, mirrored_cls, sync_on_read_cls,
+    *args, **kwargs):
   # Figure out what collections this variable should be added to.
   # We'll add the MirroredVariable to those collections instead.
   var_collections = kwargs.pop("collections", None)
@@ -1079,17 +772,9 @@ def create_mirrored_variable(  # pylint: disable=missing-docstring
   # was never recorded on the tape instead of having to do this manually
   # here.
   with tape.stop_recording():
-    devices = device_map.logical_to_actual_devices(logical_device)
-    value_list = real_mirrored_creator(devices, *args, **kwargs)
-
+    value_list = real_mirrored_creator(*args, **kwargs)
     var_cls = sync_on_read_cls if is_sync_on_read else mirrored_cls
-
-    result = var_cls(
-        strategy,
-        device_map,
-        value_list,
-        aggregation,
-        logical_device=logical_device)
+    result = var_cls(strategy, value_list, aggregation)
 
   # Add the wrapped variable to the requested collections.
   # The handling of eager mode and the global step matches
@@ -1120,14 +805,8 @@ def create_mirrored_variable(  # pylint: disable=missing-docstring
 class MirroredVariable(DistributedVariable, Mirrored):
   """Holds a map from replica to variables whose values are kept in sync."""
 
-  def __init__(self,
-               strategy,
-               device_map,
-               values,
-               aggregation,
-               logical_device=None):
-    super(MirroredVariable, self).__init__(
-        strategy, device_map, values, logical_device=logical_device)
+  def __init__(self, strategy, values, aggregation):
+    super(MirroredVariable, self).__init__(strategy, values)
     self._aggregation = aggregation
 
   # The arguments to update() are automatically unwrapped so the update()
@@ -1187,17 +866,12 @@ class MirroredVariable(DistributedVariable, Mirrored):
     return self._aggregation
 
   def _get_cross_replica(self):
-    device = device_util.canonicalize(device_util.current())
-    replica_id = self._device_map.replica_for_device(device)
-    if replica_id is None:
-      return array_ops.identity(self.primary)
-    return array_ops.identity(self._values[replica_id])
+    # Return identity, to avoid directly exposing the variable to the user and
+    # allowing it to be modified by mistake.
+    return array_ops.identity(Mirrored._get_cross_replica(self))
 
   def _as_graph_element(self):
-    # pylint: disable=protected-access
-    if distribution_strategy_context.in_cross_replica_context():
-      return self.primary._as_graph_element()
-    return self.get()._as_graph_element()
+    return self._get_closest()._as_graph_element()  # pylint: disable=protected-access
 
   def _gather_saveables_for_checkpoint(self):
     """Overrides Trackable method.
@@ -1344,15 +1018,9 @@ def _assert_replica_context(strategy):
 class SyncOnReadVariable(DistributedVariable):
   """Holds a map from replica to variables whose values are reduced on save."""
 
-  def __init__(self,
-               strategy,
-               device_map,
-               values,
-               aggregation,
-               logical_device=None):
+  def __init__(self, strategy, values, aggregation):
+    super(SyncOnReadVariable, self).__init__(strategy, values)
     self._aggregation = aggregation
-    super(SyncOnReadVariable, self).__init__(
-        strategy, device_map, values, logical_device=logical_device)
 
   def assign_sub(self, *args, **kwargs):
     with _enter_or_assert_strategy(self._distribute_strategy):
@@ -1392,7 +1060,7 @@ class SyncOnReadVariable(DistributedVariable):
         # when saving.
         tensor = args[0]
         if self._aggregation == vs.VariableAggregation.SUM:
-          tensor = math_ops.cast(tensor / len(self.devices), self.dtype)
+          tensor = math_ops.cast(tensor / len(self._values), self.dtype)
         return control_flow_ops.group(
             tuple(_assign_on_device(v.device, v, tensor) for v in self._values))
       else:
@@ -1479,10 +1147,8 @@ class TPUSyncOnReadVariable(TPUVariableMixin, SyncOnReadVariable):
     return False
 
 
-def regroup(device_map, values, wrap_class=PerReplica):
+def regroup(values, wrap_class=PerReplica):
   """Makes a nest per-replica into a nest of PerReplica/Mirrored values."""
-  assert isinstance(device_map, DeviceMap)
-  assert len(values) == device_map.num_replicas_in_graph
   v0 = values[0]
 
   if isinstance(v0, list):
@@ -1491,8 +1157,7 @@ def regroup(device_map, values, wrap_class=PerReplica):
       assert len(v) == len(v0), ("len(v) == %d, len(v0) == %d, v: %s, v0: %s" %
                                  (len(v), len(v0), v, v0))
     return [
-        regroup(device_map, tuple(v[i]
-                                  for v in values), wrap_class)
+        regroup(tuple(v[i] for v in values), wrap_class)
         for i in range(len(v0))
     ]
 
@@ -1501,8 +1166,7 @@ def regroup(device_map, values, wrap_class=PerReplica):
       assert isinstance(v, tuple)
       assert len(v) == len(v0)
     regrouped_tuple = tuple(
-        regroup(device_map, tuple(v[i]
-                                  for v in values), wrap_class)
+        regroup(tuple(v[i] for v in values), wrap_class)
         for i in range(len(v0)))
     if hasattr(v0, "_fields"):
       # This tuple is in fact a namedtuple! Create a new namedtuple instance
@@ -1519,7 +1183,7 @@ def regroup(device_map, values, wrap_class=PerReplica):
       assert set(v.keys()) == v0keys, ("v[0].keys: %s  v[i].keys: %s" %
                                        (v0keys, set(v.keys())))
     return {
-        key: regroup(device_map, tuple(v[key] for v in values), wrap_class)
+        key: regroup(tuple(v[key] for v in values), wrap_class)
         for key in v0keys
     }
 
@@ -1555,20 +1219,14 @@ def regroup(device_map, values, wrap_class=PerReplica):
     # pylint: disable=protected-access
     assert not isinstance(v0, MirroredVariable), (
         "ids = %s, values = %s" % ([id(v) for v in values], values))
-    assert device_map.is_device_in_replica(
-        v0.device,
-        0), ("v0.device = %s, device_map = %s" % (v0.device, device_map))
     distributed_container = v0._distributed_container()
     assert distributed_container is not None
-    for r, v in enumerate(values[1:]):
-      assert device_map.is_device_in_replica(
-          v.device, r + 1), ("v.device = %s, r = %d, device_map = %s" %
-                             (v.device, r + 1, device_map))
+    for v in values[1:]:
       assert distributed_container is v._distributed_container()
     return distributed_container
   # pylint: enable=protected-access
 
-  return wrap_class(device_map, values)
+  return wrap_class(values)
 
 
 def select_replica(replica_id, structured):
@@ -1587,8 +1245,8 @@ def select_replica(replica_id, structured):
   return nest.map_structure(_get, structured)
 
 
-def select_device_mirrored(device, structured):
-  """Specialize a nest of regular & mirrored values for one device."""
+def select_replica_mirrored(replica_id, structured):
+  """Specialize a nest of regular & mirrored values for one replica."""
 
   def _get_mirrored(x):
     if isinstance(x, DistributedValues):
@@ -1596,23 +1254,23 @@ def select_device_mirrored(device, structured):
         raise TypeError(
             "Expected value to be mirrored across replicas: %s in %s." %
             (x, structured))
-      return x.get(device)
+      return x.values[replica_id]
     else:
       return x
 
   return nest.map_structure(_get_mirrored, structured)
 
 
-def update_regroup(extended, device_map, updates, group):
+def update_regroup(extended, updates, group):
   """Regroup for an update, with dependencies to ensure all updates execute."""
   if not group:
-    regrouped = regroup(device_map, updates, Mirrored)
+    regrouped = regroup(updates, Mirrored)
     return nest.map_structure(extended._local_results, regrouped)  # pylint: disable=protected-access
 
-  def _make_grouped_mirrored(device_map, values):
+  def _make_grouped_mirrored(values):
     """Convert per-replica list `values` into Mirrored type with grouping."""
     if len(values) == 1:
-      return Mirrored(device_map, values)
+      return Mirrored(values)
 
     # Make sure we run all updates. Without this, something like
     # session.run(extended.update(...)) may only update one replica.
@@ -1626,17 +1284,14 @@ def update_regroup(extended, device_map, updates, group):
 
     # Otherwise we need tensors with the same values as `values`, but
     # that have a dependency on `g`.
-    devices = device_map.logical_to_actual_devices(
-        device_map.logical_device_from_values(values))
-    assert len(values) == len(devices)
     with_dep = []
-    for v, d in zip(values, devices):
-      with ops.device(d), ops.control_dependencies([g]):
+    for v in values:
+      with ops.device(v.device), ops.control_dependencies([g]):
         with_dep.append(array_ops.identity(v))
 
-    return Mirrored(device_map, with_dep)
+    return Mirrored(with_dep)
 
-  return regroup(device_map, updates, _make_grouped_mirrored)
+  return regroup(updates, _make_grouped_mirrored)
 
 
 def value_container(val):
diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py
index d97d1155c82..01022b6e110 100644
--- a/tensorflow/python/distribute/values_test.py
+++ b/tensorflow/python/distribute/values_test.py
@@ -24,7 +24,7 @@ import os
 from absl.testing import parameterized
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.distribute import combinations
-from tensorflow.python.distribute import device_util
+from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.distribute import tpu_strategy
@@ -55,63 +55,35 @@ from tensorflow.python.util import nest
 class DistributedValuesTest(test.TestCase):
 
   def testGetEager(self):
-    with ops.device("/device:CPU:0"):
-      one = constant_op.constant(1)
-      two = constant_op.constant(2)
-      device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0"))
-      v = values.DistributedValues(device_map, (one, two))
-      self.assertEqual(two, v.get("/device:GPU:0"))
-      self.assertEqual(one, v.get())
-      with self.assertRaises(ValueError):
-        self.assertIsNone(v.get("/device:GPU:2"))
+    one = constant_op.constant(1)
+    two = constant_op.constant(2)
+    v = values.DistributedValues((one, two))
+    self.assertEqual(one, v.get())
+    with distribute_lib.ReplicaContext(None, 1):
+      self.assertEqual(two, v.get())
 
   def testGetGraph(self):
-    with context.graph_mode(), \
-        ops.Graph().as_default(), \
-        ops.device("/device:CPU:0"):
+    with context.graph_mode(), ops.Graph().as_default():
       one = constant_op.constant(1)
       two = constant_op.constant(2)
-      device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0"))
-      v = values.DistributedValues(device_map, (one, two))
-      self.assertEqual(two, v.get("/device:GPU:0"))
+      v = values.DistributedValues((one, two))
       self.assertEqual(one, v.get())
-      with self.assertRaises(ValueError):
-        self.assertIsNone(v.get("/device:GPU:2"))
-
-  def testCanonicalization(self):
-    canonical_cpu = ("/job:localhost/replica:0/task:0/device:CPU:0",)
-    v = values.DistributedValues(values.SingleDeviceMap(""), (42,))
-    self.assertEqual(canonical_cpu, v.devices)
-    v = values.DistributedValues(values.SingleDeviceMap("/device:CPU:0"), (42,))
-    self.assertEqual(canonical_cpu, v.devices)
-    v = values.DistributedValues(values.SingleDeviceMap("/cpu:0"), (42,))
-    self.assertEqual(canonical_cpu, v.devices)
-    v = values.DistributedValues(values.SingleDeviceMap("/CPU:0"), (42,))
-    self.assertEqual(canonical_cpu, v.devices)
+      with distribute_lib.ReplicaContext(None, 1):
+        self.assertEqual(two, v.get())
 
   def testIsTensorLike(self):
-    with context.graph_mode(), \
-         ops.Graph().as_default(), \
-         ops.device("/device:CPU:0"):
+    with context.graph_mode(), ops.Graph().as_default():
       one = constant_op.constant(1)
       two = constant_op.constant(2)
-      device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0"))
-      v = values.DistributedValues(device_map, (one, two))
-      self.assertEqual(two, v.get("/device:GPU:0"))
-      self.assertEqual(one, v.get())
+      v = values.DistributedValues((one, two))
       self.assertTrue(v.is_tensor_like)
       self.assertTrue(tensor_util.is_tensor(v))
 
   def testIsTensorLikeWithAConstant(self):
-    with context.graph_mode(), \
-         ops.Graph().as_default(), \
-         ops.device("/device:CPU:0"):
+    with context.graph_mode(), ops.Graph().as_default():
       one = constant_op.constant(1)
       two = 2.0
-      device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0"))
-      v = values.DistributedValues(device_map, (one, two))
-      self.assertEqual(two, v.get("/device:GPU:0"))
-      self.assertEqual(one, v.get())
+      v = values.DistributedValues((one, two))
       self.assertFalse(v.is_tensor_like)
       self.assertFalse(tensor_util.is_tensor(v))
 
@@ -120,62 +92,59 @@ class DistributedDelegateTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testGetAttr(self):
-    with ops.device("/device:CPU:0"):
+    class Foo(object):
 
-      class Foo(object):
+      def __init__(self, x):
+        self.x = x
 
-        def __init__(self, x):
-          self.x = x
-
-      device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0"))
-      v = values.DistributedDelegate(device_map, (Foo(7), Foo(8)))
-      self.assertEqual(7, v.x)
-      with self.assertRaises(AttributeError):
-        _ = v.y
+    v = values.DistributedDelegate((Foo(7), Foo(8)))
+    self.assertEqual(7, v.x)
+    with self.assertRaises(AttributeError):
+      _ = v.y
 
   @test_util.run_in_graph_and_eager_modes
   def testOperatorOverride(self):
-    with ops.device("/device:CPU:0"):
-      device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0"))
-      v = values.DistributedDelegate(device_map, (7, 8))
-      # v should act like int(7).
-      self.assertEqual(8, v + 1)
-      self.assertEqual(10, 3 + v)
-      self.assertEqual(14, v + v)
-      self.assertEqual(5, v - 2)
-      self.assertEqual(6, 13 - v)
-      self.assertEqual(0, v - v)
-      self.assertEqual(14, v * 2)
-      self.assertEqual(21, 3 * v)
-      self.assertEqual(49, v * v)
-      self.assertEqual(3.5, v / 2)
-      self.assertEqual(1.5, 10.5 / v)
-      self.assertEqual(3, v // 2)
-      self.assertEqual(2, 15 // v)
-      self.assertEqual(1, v % 2)
-      self.assertEqual(2, 16 % v)
-      self.assertTrue(v < 12)
-      self.assertTrue(v <= 12)
-      self.assertFalse(v > 12)
-      self.assertFalse(v >= 12)
-      self.assertFalse(12 < v)
-      self.assertFalse(12 <= v)
-      self.assertTrue(12 > v)
-      self.assertTrue(12 >= v)
-      self.assertEqual(3, v & 3)
-      self.assertEqual(3, 11 & v)
-      self.assertEqual(15, v | 8)
-      self.assertEqual(23, 16 | v)
-      self.assertEqual(4, v ^ 3)
-      self.assertEqual(12, 11 ^ v)
-      self.assertEqual(343, pow(v, 3))
-      self.assertEqual(3, pow(v, 3, 10))
-      self.assertEqual(128, pow(2, v))
-      self.assertEqual(-7, -v)
-      self.assertEqual(~7, ~v)
-      self.assertEqual(7, abs(v))
-      with self.assertRaises(TypeError):
-        _ = v[2]
+    v = values.DistributedDelegate((7, 8))
+    # v should act like int(7).
+    self.assertEqual(8, v + 1)
+    self.assertEqual(10, 3 + v)
+    self.assertEqual(14, v + v)
+    self.assertEqual(5, v - 2)
+    self.assertEqual(6, 13 - v)
+    self.assertEqual(0, v - v)
+    self.assertEqual(14, v * 2)
+    self.assertEqual(21, 3 * v)
+    self.assertEqual(49, v * v)
+    self.assertEqual(3.5, v / 2)
+    self.assertEqual(1.5, 10.5 / v)
+    self.assertEqual(3, v // 2)
+    self.assertEqual(2, 15 // v)
+    self.assertEqual(1, v % 2)
+    self.assertEqual(2, 16 % v)
+    # pylint: disable=g-generic-assert
+    self.assertTrue(v < 12)
+    self.assertTrue(v <= 12)
+    self.assertFalse(v > 12)
+    self.assertFalse(v >= 12)
+    self.assertFalse(12 < v)
+    self.assertFalse(12 <= v)
+    self.assertTrue(12 > v)
+    self.assertTrue(12 >= v)
+    # pylint: enable=g-generic-assert
+    self.assertEqual(3, v & 3)
+    self.assertEqual(3, 11 & v)
+    self.assertEqual(15, v | 8)
+    self.assertEqual(23, 16 | v)
+    self.assertEqual(4, v ^ 3)
+    self.assertEqual(12, 11 ^ v)
+    self.assertEqual(343, pow(v, 3))
+    self.assertEqual(3, pow(v, 3, 10))
+    self.assertEqual(128, pow(2, v))
+    self.assertEqual(-7, -v)
+    self.assertEqual(~7, ~v)
+    self.assertEqual(7, abs(v))
+    with self.assertRaises(TypeError):
+      _ = v[2]
 
 
 def _device_str(d):
@@ -185,15 +154,15 @@ def _device_str(d):
 def _nested_value(d):
   return ("a" + d, ["b" + d, {"c": "d" + d, "e": "f" + d}, "g" + d], "h" + d)
 
+
 def _make_mirrored_val(init_val=5.0):
   v = []
   devices = ["/device:GPU:0", "/device:CPU:0"]
   for d, _ in zip(devices, ["v", "v/replica"]):
     with ops.device(d):
       v.append(constant_op.constant(init_val))
-  device_map = values.ReplicaDeviceMap(devices)
-  mirrored = values.Mirrored(device_map, v)
-  return mirrored
+  return values.Mirrored(v)
+
 
 def _make_mirrored():
   v = []
@@ -202,29 +171,20 @@ def _make_mirrored():
     with ops.device(d):
       v.append(variable_scope.get_variable(
           name=n, initializer=init, use_resource=True))
-  device_map = values.ReplicaDeviceMap(devices)
-  mirrored = values.MirroredVariable(None, device_map, v,
-                                     variable_scope.VariableAggregation.SUM)
-  return v, device_map, mirrored
+  mirrored = values.MirroredVariable(
+      None, v, variable_scope.VariableAggregation.SUM)
+  return mirrored
 
 
 class RegroupAndSelectDeviceTest(test.TestCase):
 
   def _is_per_replica(self, result, expected, klass=values.PerReplica):
     self.assertIsInstance(result, klass)
-    # We canonicalize the devices to match the device strings returned
-    # by PerReplica, which also does device string canonicalization.
-    devices = [device_util.canonicalize(_device_str(i))
-               for i in range(len(expected))]
-    self.assertEqual(set(devices), set(result.devices))
-    for i, d in enumerate(devices):
-      self.assertEqual(expected[i], result.get(d))
-      self.assertEqual(expected[i], result.get(_device_str(i)))
+    for i, exp in enumerate(expected):
+      self.assertEqual(exp, result.values[i])
 
   def testNested(self):
-    device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
-    result = values.regroup(device_map,
-                            (_nested_value("1"), _nested_value("2")))
+    result = values.regroup((_nested_value("1"), _nested_value("2")))
     self.assertIsInstance(result, tuple)
     self.assertEqual(3, len(result))
     self._is_per_replica(result[0], ["a1", "a2"])
@@ -247,16 +207,14 @@ class RegroupAndSelectDeviceTest(test.TestCase):
                      values.select_replica(1, result))
     # select_device_mirrored() should fail due to non-mirrored values
     with self.assertRaises(TypeError):
-      values.select_device_mirrored(_device_str(0), result)
+      values.select_replica_mirrored(0, result)
     with self.assertRaises(TypeError):
-      values.select_device_mirrored(_device_str(1), result)
+      values.select_replica_mirrored(1, result)
 
   def testWrapClass(self):
     # Normally a mirrored value would be the same across devices, but
     # for a test it is convenient to be able to tell the values apart.
-    device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
-    result = values.regroup(device_map,
-                            (_nested_value("1"), _nested_value("2")),
+    result = values.regroup((_nested_value("1"), _nested_value("2")),
                             values.Mirrored)
     self.assertIsInstance(result, tuple)
     self.assertEqual(3, len(result))
@@ -280,13 +238,12 @@ class RegroupAndSelectDeviceTest(test.TestCase):
                      values.select_replica(1, result))
     # Values are marked as mirrored, so select_device_mirrored() is allowed.
     self.assertEqual(_nested_value("1"),
-                     values.select_device_mirrored(_device_str(0), result))
+                     values.select_replica_mirrored(0, result))
     self.assertEqual(_nested_value("2"),
-                     values.select_device_mirrored(_device_str(1), result))
+                     values.select_replica_mirrored(1, result))
 
   def testWrapAListOfTwoTuples(self):
-    device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
-    result = values.regroup(device_map, [("1", "2"), ("3", "4")])
+    result = values.regroup([("1", "2"), ("3", "4")])
     self.assertIsInstance(result, tuple)
     self.assertEqual(2, len(result))
     self._is_per_replica(result[0], ("1", "3"), values.PerReplica)
@@ -295,14 +252,13 @@ class RegroupAndSelectDeviceTest(test.TestCase):
   def testMirroredContainer(self):
     if context.num_gpus() < 1 and context.executing_eagerly():
       self.skipTest("A GPU is not available for this test in eager mode.")
-    v, device_map, mirrored = _make_mirrored()
-    result = values.regroup(device_map, v)
+    mirrored = _make_mirrored()
+    result = values.regroup(mirrored.values)
     self.assertIs(mirrored, result)
 
   def testSameId(self):
     foo = object()
-    device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
-    result = values.regroup(device_map, (("a", foo), ("b", foo)))
+    result = values.regroup((("a", foo), ("b", foo)))
     self.assertIsInstance(result, tuple)
     self.assertEqual(2, len(result))
     self._is_per_replica(result[0], ["a", "b"])
@@ -321,8 +277,7 @@ class RegroupAndSelectDeviceTest(test.TestCase):
     self.assertIs(foo, result_1[1])
 
   def testOneDevice(self):
-    device_map = values.ReplicaDeviceMap((_device_str(0),))
-    result = values.regroup(device_map, (_nested_value("1"),))
+    result = values.regroup((_nested_value("1"),))
     # On one device regroup() and select_replica() are basically identity.
     self.assertEqual(_nested_value("1"), result)
     self.assertEqual(_nested_value("1"),
@@ -333,10 +288,9 @@ class RegroupAndSelectDeviceTest(test.TestCase):
     with ops.device(d):
       v = variable_scope.get_variable(
           name="v", initializer=1., use_resource=True)
-      device_map = values.ReplicaDeviceMap((d,))
-    mirrored = values.MirroredVariable(None, device_map, (v,),
+    mirrored = values.MirroredVariable(None, (v,),
                                        variable_scope.VariableAggregation.SUM)
-    result = values.regroup(device_map, (v,))
+    result = values.regroup((v,))
     self.assertIs(mirrored, result)
 
   def testNamedTuple(self):
@@ -356,7 +310,6 @@ class RegroupAndSelectDeviceTest(test.TestCase):
             scaffold=scaffold or Scaffold())
 
     with context.graph_mode(), ops.Graph().as_default():
-      devices = []
       created_estimator_specs = []
 
       for device_id in range(3):
@@ -364,25 +317,21 @@ class RegroupAndSelectDeviceTest(test.TestCase):
             mode=mode_keys.EstimatorModeKeys.TRAIN,
             loss=constant_op.constant(device_id / 2),
             train_op=array_ops.identity(constant_op.constant(device_id)))
-        devices.append(_device_str(device_id))
         created_estimator_specs.append(spec)
 
-      device_map = values.ReplicaDeviceMap(devices)
-      merged_estimator_spec = values.regroup(
-          device_map, created_estimator_specs)
+      merged_estimator_spec = values.regroup(created_estimator_specs)
 
       self.assertIsInstance(merged_estimator_spec, EstimatorSpec)
       self.assertEqual(mode_keys.EstimatorModeKeys.TRAIN,
                        merged_estimator_spec.mode)
       for device_id in range(3):
-        d = _device_str(device_id)
         self.assertEqual(created_estimator_specs[device_id].loss,
-                         merged_estimator_spec.loss.get(d))
+                         merged_estimator_spec.loss.values[device_id])
         self.assertEqual(created_estimator_specs[device_id].train_op,
-                         merged_estimator_spec.train_op.get(d))
+                         merged_estimator_spec.train_op.values[device_id])
         # Scaffold is populated by `EstimatorSpec.__new__`.
         self.assertEqual(created_estimator_specs[device_id].scaffold,
-                         merged_estimator_spec.scaffold.get(d))
+                         merged_estimator_spec.scaffold.values[device_id])
         self.assertIsInstance(created_estimator_specs[device_id].scaffold,
                               Scaffold)
         # Also test that we can undo the merge using select_replica()
@@ -401,28 +350,26 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
     if context.num_gpus() < 1 and context.executing_eagerly():
       self.skipTest("A GPU is not available for this test in eager mode.")
 
-    v, _, mirrored = _make_mirrored()
-
-    self.assertEqual(v[0].name, mirrored.name)
-    self.assertEqual(v[0].dtype, mirrored.dtype)
-    self.assertEqual(v[0].shape, mirrored.shape)
+    mirrored = _make_mirrored()
+    v = mirrored.values[0]
+    self.assertEqual(v.name, mirrored.name)
+    self.assertEqual(v.dtype, mirrored.dtype)
+    self.assertEqual(v.shape, mirrored.shape)
 
   @test_util.run_in_graph_and_eager_modes(config=config)
   def testVariableOnAnotherDevice(self):
     v = variable_scope.get_variable(
         name="v", initializer=[1.], use_resource=True)
-    device_map = values.ReplicaDeviceMap(("/job:foo/device:CPU:0",))
-    mirrored = values.MirroredVariable(None, device_map, (v,),
-                                       variable_scope.VariableAggregation.MEAN)
+    mirrored = values.MirroredVariable(
+        None, (v,), variable_scope.VariableAggregation.MEAN)
 
     self.assertEqual(v.name, mirrored.name)
     self.assertEqual(v.dtype, mirrored.dtype)
     self.assertEqual(v.shape, mirrored.shape)
 
-  def _assign_mirrored(self, devices, v, new):
-    for d, var, n in zip(devices, v, new):
-      with ops.device(d):
-        self.evaluate(var.assign(n))
+  def _assign_mirrored(self, v, new):
+    for var, n in zip(v.values, new):
+      self.evaluate(var.assign(n))
 
   def _save_return_saver(self, sess, var):
     saver = saver_lib.Saver(var_list=[var])
@@ -445,17 +392,17 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
       self.skipTest("A GPU is not available for this test in eager mode.")
 
     with self.cached_session(config=self.config) as sess:
-      v, device_map, mirrored = _make_mirrored()
-      devices = device_map.all_devices
+      mirrored = _make_mirrored()
+      v = mirrored.values
 
       # Overwrite the initial values.
-      self._assign_mirrored(devices, v, [3., 4.])
+      self._assign_mirrored(mirrored, [3., 4.])
 
       # Saves the current value of v[0], 3.
       save_path, saver = self._save_return_saver(sess, mirrored)
 
       # Change the values between save and restore.
-      self._assign_mirrored(devices, v, [5., 6.])
+      self._assign_mirrored(mirrored, [5., 6.])
 
       # Restores the saved value of 3. to both variables.
       saver.restore(sess, save_path)
@@ -464,17 +411,16 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
   def _save_mirrored(self):
     """Save variables with mirroring, returns save_path."""
     with self.session(graph=ops.Graph()) as sess:
-      v, device_map, mirrored = _make_mirrored()
-      devices = device_map.all_devices
+      mirrored = _make_mirrored()
 
       # Overwrite the initial values.
-      self._assign_mirrored(devices, v, [3., 4.])
+      self._assign_mirrored(mirrored, [3., 4.])
 
       # Saves the current value of v[0], 3.
       save_path = self._save(sess, mirrored)
 
       # Change the values between save and restore.
-      self._assign_mirrored(devices, v, [5., 6.])
+      self._assign_mirrored(mirrored, [5., 6.])
     return save_path
 
   def _save_normal(self):
@@ -510,11 +456,11 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
   def _restore_mirrored(self, save_path):
     """Restore to variables with mirroring in a fresh graph."""
     with self.session(graph=ops.Graph()) as sess:
-      v, device_map, mirrored = _make_mirrored()
-      devices = device_map.all_devices
+      mirrored = _make_mirrored()
+      v = mirrored.values
 
       # Overwrite the initial values.
-      self._assign_mirrored(devices, v, [7., 8.])
+      self._assign_mirrored(mirrored, [7., 8.])
 
       # Restores the saved value of 3. to both variables.
       saver = saver_lib.Saver(var_list=[mirrored])
@@ -572,8 +518,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
         v = variable_scope.get_variable(
             name="v", initializer=1., use_resource=True)
       mirrored = values.MirroredVariable(
-          distribution, values.ReplicaDeviceMap(("/device:GPU:0",)), (v,),
-          variable_scope.VariableAggregation.MEAN)
+          distribution, (v,), variable_scope.VariableAggregation.MEAN)
       sess.run(variables_lib.global_variables_initializer())
       sess.run({"complicated": mirrored})
 
@@ -744,7 +689,6 @@ def _make_replica_local(method, strategy=None):
   else:
     devices = strategy.extended.worker_devices
 
-  device_map = values.ReplicaDeviceMap(devices)
   v = []
   for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
     with ops.device(d):
@@ -755,7 +699,7 @@ def _make_replica_local(method, strategy=None):
     var_cls = values.TPUSyncOnReadVariable
   else:
     var_cls = values.SyncOnReadVariable
-  replica_local = var_cls(strategy, device_map, v, method)
+  replica_local = var_cls(strategy, v, method)
   return v, replica_local
 
 
@@ -777,20 +721,6 @@ class SyncOnReadVariablePropertiesTest(test.TestCase):
     self.assertEqual(variable_scope.VariableAggregation.SUM,
                      replica_local.aggregation)
 
-  @test_util.run_in_graph_and_eager_modes(config=config)
-  def testVariableOnAnotherDevice(self):
-    v = variable_scope.get_variable(
-        name="v", initializer=[1.], use_resource=True)
-    device_map = values.ReplicaDeviceMap(("/job:foo/device:CPU:0",))
-    replica_local = values.SyncOnReadVariable(
-        None, device_map, (v,), variable_scope.VariableAggregation.MEAN)
-
-    self.assertEqual(v.name, replica_local.name)
-    self.assertEqual(v.dtype, replica_local.dtype)
-    self.assertEqual(v.shape, replica_local.shape)
-    self.assertEqual(variable_scope.VariableAggregation.MEAN,
-                     replica_local.aggregation)
-
   def testTensorConversion(self):
     with context.graph_mode():
       _, replica_local = _make_replica_local(
@@ -812,9 +742,8 @@ class SyncOnReadVariablePropertiesTest(test.TestCase):
 
     v = variable_scope.get_variable(
         name="v", initializer=[1.], use_resource=True)
-    device_map = values.ReplicaDeviceMap(("/job:foo/device:CPU:0",))
     replica_local = values.SyncOnReadVariable(
-        None, device_map, (v,), variable_scope.VariableAggregation.MEAN)
+        None, (v,), variable_scope.VariableAggregation.MEAN)
     self.assertEqual(2., self.evaluate(add1(replica_local)))
 
 
@@ -1171,6 +1100,7 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
     vals = self.evaluate(v[0].values)
     self.assertAllEqual(vals[0], vals[1])
 
+
 class MirroredTest(test.TestCase):
 
   def testAddOp(self):
@@ -1191,49 +1121,39 @@ class MirroredTest(test.TestCase):
 class PerReplicaTest(test.TestCase, parameterized.TestCase):
 
   def testTypeSpec(self):
-    device_map = values.SingleDeviceMap("CPU")
     vals = (constant_op.constant(1.),)
-    per_replica = values.PerReplica(device_map, vals)
+    per_replica = values.PerReplica(vals)
 
     spec = per_replica._type_spec
     self.assertEqual(spec._value_specs,
                      (tensor_spec.TensorSpec([], dtypes.float32),))
-    self.assertEqual(spec._device_map, per_replica.device_map)
-    self.assertEqual(spec._logical_device, per_replica.logical_device)
 
   def testTypeSpecRoundTrip(self):
-    device_map = values.SingleDeviceMap("CPU")
     vals = (constant_op.constant(1.),)
-    per_replica = values.PerReplica(device_map, vals)
+    per_replica = values.PerReplica(vals)
 
     spec = per_replica._type_spec
     tensor_list = spec._to_components(per_replica)
     reconstructed = spec._from_components(tensor_list)
 
-    self.assertEqual(per_replica.device_map, reconstructed.device_map)
-    self.assertEqual(per_replica.logical_device, reconstructed.logical_device)
     self.assertAllEqual(per_replica.values, reconstructed.values)
 
   def testTypeSpecNest(self):
-    device_map = values.ReplicaDeviceMap(["CPU:0", "CPU:1"])
     vals = (constant_op.constant(1.), constant_op.constant([5., 6.0]),)
-    per_replica = values.PerReplica(device_map, vals)
+    per_replica = values.PerReplica(vals)
 
     # Note: nest.map_structutre exercises nest.flatten and
     # nest.pack_sequence_as.
-    result = nest.map_structure(lambda t: t + 10, per_replica,
-                                expand_composites=True)
+    result = nest.map_structure(
+        lambda t: t + 10, per_replica, expand_composites=True)
 
-    self.assertEqual(per_replica.device_map, result.device_map)
-    self.assertEqual(per_replica.logical_device, result.logical_device)
     self.assertLen(result.values, 2)
     self.assertAllEqual(result.values[0], 11.)
     self.assertAllEqual(result.values[1], [15., 16.0])
 
   @test_util.run_in_graph_and_eager_modes
   def testIsGraphTensor(self):
-    per_replica = values.PerReplica(values.SingleDeviceMap("CPU"),
-                                    (constant_op.constant(1.),))
+    per_replica = values.PerReplica((constant_op.constant(1.),))
     for t in nest.flatten(per_replica, expand_composites=True):
       self.assertEqual(hasattr(t, "graph"), not context.executing_eagerly())
 
@@ -1245,8 +1165,7 @@ class PerReplicaTest(test.TestCase, parameterized.TestCase):
       traces.append(None)  # Only happens on trace.
       return x
 
-    per_replica = values.PerReplica(
-        values.SingleDeviceMap("CPU"), (constant_op.constant(1.),))
+    per_replica = values.PerReplica((constant_op.constant(1.),))
 
     # Trace once.
     f(per_replica)
@@ -1262,14 +1181,11 @@ class PerReplicaTest(test.TestCase, parameterized.TestCase):
       output = f(per_replica)
       self.assertIsInstance(output, values.PerReplica)
       self.assertAllEqual(output._values, per_replica._values)
-      self.assertAllEqual(output._device_map, per_replica._device_map)
-      self.assertAllEqual(output._logical_device, per_replica._logical_device)
       self.assertEmpty(traces)  # Make sure we're not re-tracing `f`.
 
   def testFunctionCanReturnPerReplica(self):
     f = def_function.function(lambda x: x)
-    x = values.PerReplica(
-        values.SingleDeviceMap("CPU"), (constant_op.constant(1.),))
+    x = values.PerReplica((constant_op.constant(1.),))
     y = f(x)
     self.assertIsNot(x, y)
     nest.map_structure(self.assertAllEqual, x, y, expand_composites=True)
@@ -1277,40 +1193,32 @@ class PerReplicaTest(test.TestCase, parameterized.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testCondWithTensorValues(self):
-    device_map = values.SingleDeviceMap("CPU")
-    per_replica_1 = values.PerReplica(device_map, (constant_op.constant("a"),))
-    per_replica_2 = values.PerReplica(device_map,
-                                      (constant_op.constant(["b", "c"]),))
+    per_replica_1 = values.PerReplica((constant_op.constant("a"),))
+    per_replica_2 = values.PerReplica((constant_op.constant(["b", "c"]),))
     condition = array_ops.placeholder_with_default(True, [])
 
     result = control_flow_ops.cond(
         condition, lambda: per_replica_1, lambda: per_replica_2)
 
-    self.assertEqual(per_replica_1.device_map, result.device_map)
-    self.assertEqual(per_replica_1.logical_device, result.logical_device)
     self.assertLen(result.values, 1)
     self.assertAllEqual(result.values[0], "a")
 
   @test_util.run_in_graph_and_eager_modes
   def testCondWithValuesConvertibleToTensor(self):
-    device_map = values.SingleDeviceMap("CPU")
-    per_replica_1 = values.PerReplica(device_map, ("a",))
-    per_replica_2 = values.PerReplica(device_map, ("b",))
+    per_replica_1 = values.PerReplica(("a",))
+    per_replica_2 = values.PerReplica(("b",))
     condition = array_ops.placeholder_with_default(True, [])
 
     result = control_flow_ops.cond(
         condition, lambda: per_replica_1, lambda: per_replica_2)
 
-    self.assertEqual(per_replica_1.device_map, result.device_map)
-    self.assertEqual(per_replica_1.logical_device, result.logical_device)
     self.assertLen(result.values, 1)
     self.assertAllEqual(result.values[0], "a")
 
   @test_util.build_as_function_and_v1_graph
   def testCondWithValuesNotConvertibleToTensor(self):
-    device_map = values.SingleDeviceMap("CPU")
-    per_replica_1 = values.PerReplica(device_map, (set(["a"]),))
-    per_replica_2 = values.PerReplica(device_map, (set(["b", "c"]),))
+    per_replica_1 = values.PerReplica(({"a"},))
+    per_replica_2 = values.PerReplica(({"b", "c"},))
     condition = array_ops.placeholder(dtypes.bool, [])
 
     with self.assertRaisesRegex(TypeError, "Could not build a TypeSpec for"):
@@ -1318,88 +1226,5 @@ class PerReplicaTest(test.TestCase, parameterized.TestCase):
           condition, lambda: per_replica_1, lambda: per_replica_2)
 
 
-class WorkerDeviceMapTest(test.TestCase, parameterized.TestCase):
-
-  class ReplicaContext(object):
-
-    def __init__(self, replica_id_in_sync_group):
-      self.replica_id_in_sync_group = replica_id_in_sync_group
-
-  def testBasic(self):
-    devices = [
-        "/job:worker/replica:0/task:0/device:CPU:0",
-        "/job:worker/replica:0/task:2/device:CPU:0"
-    ]
-    device_map = values.WorkerDeviceMap(devices, 1)
-    self.assertAllEqual(devices, device_map.all_devices)
-
-    # pylint:disable=pointless-statement
-    with self.assertRaisesWithPredicateMatch(
-        ValueError, "`WorkerDeviceMap` is not indexed by replicas"):
-      device_map.devices_by_replica
-
-    self.assertEqual(1, device_map.num_logical_devices)
-
-    self.assertEqual(2, device_map.num_replicas_in_graph)
-
-    self.assertEqual(0, device_map.logical_device_from_values(["a", "b"]))
-
-    self.assertAllEqual(devices, device_map.logical_to_actual_devices(0))
-
-    replica_context = WorkerDeviceMapTest.ReplicaContext(1)
-    self.assertEqual(
-        "b", device_map.select_for_current_replica(["a", "b"], replica_context))
-
-    with self.assertRaisesWithPredicateMatch(
-        ValueError, "`WorkerDeviceMap` not indexed by replicas"):
-      device_map.replica_for_device(devices[1])
-
-    self.assertEqual("b", device_map.select_for_device(["a", "b"], devices[1]))
-
-    with self.assertRaisesWithPredicateMatch(
-        ValueError, "WorkerDeviceMap not indexed by replicas"):
-      device_map.is_device_in_replica(devices[1], 1)
-
-    self.assertEqual(
-        "WorkerDeviceMap(('/job:worker/replica:0/task:0/device:CPU:0', "
-        "'/job:worker/replica:0/task:2/device:CPU:0'), "
-        "num_replicas_per_worker=1)", repr(device_map))
-
-  def testMultipleReplicasPerWorker(self):
-    devices = [
-        "/job:worker/replica:0/task:0/device:CPU:0",
-        "/job:worker/replica:0/task:2/device:CPU:0"
-    ]
-    device_map = values.WorkerDeviceMap(devices, 2)
-
-    replica_context = WorkerDeviceMapTest.ReplicaContext(3)
-    self.assertEqual(
-        "b", device_map.select_for_current_replica(["a", "b"], replica_context))
-
-  @combinations.generate(
-      combinations.combine(
-          distribution=[
-              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
-              strategy_combinations.tpu_strategy,
-          ],
-          mode=["graph", "eager"]))
-  def testExperimentalLocalResultsOrder(self, distribution):
-    # Create 2 devices in the device map, where the alphabetical order and the
-    # actual order of devices are different.
-    device_map = values.ReplicaDeviceMap(["CPU:2", "CPU:10"])
-    vals = (
-        constant_op.constant(1.),
-        constant_op.constant([5., 6.0]),
-    )
-    per_replica = values.PerReplica(device_map, vals)
-    results = self.evaluate(
-        distribution.experimental_local_results(per_replica))
-
-    # We expect the outputs order the same as the inputs order.
-    self.assertLen(results, 2)
-    self.assertAllEqual(1.0, results[0])
-    self.assertAllEqual([5., 6.], results[1])
-
-
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/keras/distribute/keras_utils_test.py b/tensorflow/python/keras/distribute/keras_utils_test.py
index bf328e447c1..3529935dd51 100644
--- a/tensorflow/python/keras/distribute/keras_utils_test.py
+++ b/tensorflow/python/keras/distribute/keras_utils_test.py
@@ -197,9 +197,8 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase):
     with self.cached_session():
       a = constant_op.constant([1, 2], shape=(1, 2))
       b = constant_op.constant([[1, 2], [1, 2]], shape=(2, 2))
-      device_map = values.ReplicaDeviceMap(('/device:CPU:0', '/device:GPU:0'))
-      x = values.DistributedValues(device_map, (a, b))
-      y = values.DistributedValues(device_map, (a, a))
+      x = values.DistributedValues((a, b))
+      y = values.DistributedValues((a, a))
       # Removed device and input tensor shape details from the error message
       # since the order of the device and the corresponding input tensor shape
       # is not deterministic over different runs.
@@ -222,9 +221,8 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase):
     with self.cached_session():
       a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32)
       b = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.float64)
-      device_map = values.ReplicaDeviceMap(('/device:CPU:0', '/device:GPU:0'))
-      x = values.DistributedValues(device_map, (a, b))
-      y = values.DistributedValues(device_map, (a, a))
+      x = values.DistributedValues((a, b))
+      y = values.DistributedValues((a, a))
       # Removed device and input tensor dtype details from the error message
       # since the order of the device and the corresponding input tensor dtype
       # is not deterministic over different runs.
diff --git a/tensorflow/python/module/module_test.py b/tensorflow/python/module/module_test.py
index 71594de1058..62b777b8188 100644
--- a/tensorflow/python/module/module_test.py
+++ b/tensorflow/python/module/module_test.py
@@ -247,13 +247,10 @@ class VariableTrackingTest(test_util.TensorFlowTestCase):
     self.assertEqual(len(m.child.child.trainable_variables), 0)
 
   def test_supports_distributed_variables(self):
-    device_map = distributed_values.SingleDeviceMap("/CPU:0")
     mirrored = distributed_values.MirroredVariable(
-        None, device_map, [variables.Variable(1.)],
-        variables.VariableAggregation.SUM)
+        None, [variables.Variable(1.)], variables.VariableAggregation.SUM)
     tpu = distributed_values.TPUMirroredVariable(
         strategy=None,
-        device_map=device_map,
         values=[variables.Variable(42.)],
         aggregation=None)
     aggregating = distributed_values.AggregatingVariable(
diff --git a/tensorflow/python/ops/stateful_random_ops_test.py b/tensorflow/python/ops/stateful_random_ops_test.py
index 499698b7d57..b68753617d6 100644
--- a/tensorflow/python/ops/stateful_random_ops_test.py
+++ b/tensorflow/python/ops/stateful_random_ops_test.py
@@ -727,9 +727,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
     devices = ["cpu:0", "cpu:1"]
     strat = MirroredStrategy(devices=devices)
     # Use `PerReplica` to specify which `gen` is sent to which replica
-    gens = dist_values.PerReplica(
-        device_map=dist_values.ReplicaDeviceMap(devices),
-        values=[[g] for g in gens])
+    gens = dist_values.PerReplica([[g] for g in gens])
     with strat.scope():
       def f(gen):
         t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
diff --git a/tensorflow/python/tpu/tpu.py b/tensorflow/python/tpu/tpu.py
index 2af31a9dd58..f6835fce76d 100644
--- a/tensorflow/python/tpu/tpu.py
+++ b/tensorflow/python/tpu/tpu.py
@@ -259,11 +259,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
     self._pivot = pivot
     self._replicated_vars = {}
 
-  def get_replicated_var_handle(self,
-                                name,
-                                vars_,
-                                device_map=None,
-                                is_mirrored=False):
+  def get_replicated_var_handle(self, name, vars_, is_mirrored=False):
     """Returns a variable handle for replicated TPU variable 'var'.
 
     This is a method used by an experimental replicated variable implementation
@@ -272,8 +268,6 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
     Args:
       name: The common name of the variable.
       vars_: The replicated TPU variables.
-      device_map: The DeviceMap used to create the variables if it is a
-        TPUMirroredVariable.
       is_mirrored: Whether the variables are mirrored, which guarantees the
         values in each replica are always the same.
 
@@ -287,15 +281,20 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
     if handle is not None:
       return handle
 
-    replicated_vars = []
-    if device_assignment is not None and device_map is not None:
-      job_name = pydev.DeviceSpec.from_string(device_map.all_devices[0]).job
+    if device_assignment is not None:
+      job_name = pydev.DeviceSpec.from_string(vars_[0].device).job
+
+      tpu_devices = set()
       for replica_id in range(device_assignment.num_replicas):
-        tpu_device = device_assignment.tpu_device(
-            replica=replica_id, logical_core=0, job=job_name)
-        tpu_device = device_util.canonicalize(tpu_device)
-        replica = device_map.replica_for_device(tpu_device)
-        replicated_vars.append(vars_[replica])
+        for logical_core in range(device_assignment.num_cores_per_replica):
+          tpu_devices.add(
+              device_util.canonicalize(
+                  device_assignment.tpu_device(
+                      replica=replica_id,
+                      logical_core=logical_core,
+                      job=job_name)))
+
+      replicated_vars = [v for v in vars_ if v.device in tpu_devices]
     else:
       replicated_vars = vars_
 
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 3af76873051..f1a31d01dd4 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -768,7 +768,7 @@ class Optimizer(
       # pylint: enable=protected-access
       mirrored_slot = named_slots.get(key, None)
       if mirrored_slot is None: return None
-      return mirrored_slot.get(device=var.device)
+      return mirrored_slot._get_closest()  # pylint: disable=protected-access
 
     return named_slots.get(_var_key(var), None)
 

From 972488ce9520edece144a9d1acab8820c392fa08 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Mon, 2 Dec 2019 03:27:38 -0800
Subject: [PATCH 134/279] Introduce Linkage attribute to the LLVM dialect

LLVM IR supports linkage on global objects such as global variables and
functions. Introduce the Linkage attribute into the LLVM dialect, backed by an
integer storage. Use this attribute on LLVM::GlobalOp and make it mandatory.
Implement parsing/printing of the attribute and conversion to LLVM IR.

See #277.

PiperOrigin-RevId: 283309328
Change-Id: Id8d54054e25df32916e711b31f2c8ef168b1bbe6
---
 .../include/mlir/Dialect/LLVMIR/LLVMDialect.h |  3 +-
 .../include/mlir/Dialect/LLVMIR/LLVMOps.td    | 33 ++++++-
 .../ConvertLaunchFuncToCudaCalls.cpp          |  5 +-
 .../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp        |  4 +-
 .../lib/Dialect/LLVMIR/IR/LLVMDialect.cpp     | 94 +++++++++++++++++--
 .../lib/Target/LLVMIR/ConvertFromLLVMIR.cpp   | 38 +++++++-
 .../lib/Target/LLVMIR/ModuleTranslation.cpp   | 39 +++++++-
 7 files changed, 196 insertions(+), 20 deletions(-)

diff --git a/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index eb39537c03d..83c30e64b9f 100644
--- a/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -195,7 +195,8 @@ private:
 /// global and use it to compute the address of the first character in the
 /// string (operations inserted at the builder insertion point).
 Value *createGlobalString(Location loc, OpBuilder &builder, StringRef name,
-                          StringRef value, LLVM::LLVMDialect *llvmDialect);
+                          StringRef value, LLVM::Linkage linkage,
+                          LLVM::LLVMDialect *llvmDialect);
 
 } // end namespace LLVM
 } // end namespace mlir
diff --git a/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 3d697b78374..324937a5c6d 100644
--- a/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -467,8 +467,36 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> {
   let printer = [{ p << getOperationName(); }];
 }
 
+////////////////////////////////////////////////////////////////////////////////
 // Auxiliary operations (do not appear in LLVM IR but necessary for the dialect
 // to work correctly).
+////////////////////////////////////////////////////////////////////////////////
+
+// Linkage attribute is used on functions and globals. The order follows that of
+// https://llvm.org/docs/LangRef.html#linkage-types. The names are equivalent to
+// visible names in the IR rather than to enum values names in llvm::GlobalValue
+// since the latter is easier to change.
+def LinkagePrivate             : I64EnumAttrCase<"Private", 0>;
+def LinkageInternal            : I64EnumAttrCase<"Internal", 1>;
+def LinkageAvailableExternally : I64EnumAttrCase<"AvailableExternally", 2>;
+def LinkageLinkonce            : I64EnumAttrCase<"Linkonce", 3>;
+def LinkageWeak                : I64EnumAttrCase<"Weak", 4>;
+def LinkageCommon              : I64EnumAttrCase<"Common", 5>;
+def LinkageAppending           : I64EnumAttrCase<"Appending", 6>;
+def LinkageExternWeak          : I64EnumAttrCase<"ExternWeak", 7>;
+def LinkageLinkonceODR         : I64EnumAttrCase<"LinkonceODR", 8>;
+def LinkageWeakODR             : I64EnumAttrCase<"WeakODR", 9>;
+def LinkageExternal            : I64EnumAttrCase<"External", 10>;
+def Linkage : I64EnumAttr<
+    "Linkage",
+    "LLVM linkage types",
+    [LinkagePrivate, LinkageInternal, LinkageAvailableExternally,
+     LinkageLinkonce, LinkageWeak, LinkageCommon, LinkageAppending,
+     LinkageExternWeak, LinkageLinkonceODR, LinkageWeakODR, LinkageExternal]> {
+  let cppNamespace = "::mlir::LLVM";
+}
+
+
 def LLVM_AddressOfOp
     : LLVM_OneResultOp<"mlir.addressof">,
       Arguments<(ins FlatSymbolRefAttr:$global_name)> {
@@ -501,6 +529,7 @@ def LLVM_GlobalOp
                         [IsolatedFromAbove,
                          SingleBlockImplicitTerminator<"ReturnOp">, Symbol]>,
       Arguments<(ins TypeAttr:$type, UnitAttr:$constant, StrAttr:$sym_name,
+                 Linkage:$linkage,
                  OptionalAttr:$value,
                  DefaultValuedAttr:$addr_space)> {
   let summary = "LLVM dialect global.";
@@ -522,8 +551,8 @@ def LLVM_GlobalOp
 
   let builders = [
     OpBuilder<"Builder *builder, OperationState &result, LLVMType type, "
-              "bool isConstant, StringRef name, Attribute value, "
-              "ArrayRef attrs = {}">
+              "bool isConstant, Linkage linkage, StringRef name, "
+              "Attribute value, ArrayRef attrs = {}">
   ];
 
   let extraClassDeclaration = [{
diff --git a/third_party/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/third_party/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
index 9d8c8942051..f342083bee7 100644
--- a/third_party/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
+++ b/third_party/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
@@ -320,7 +320,7 @@ Value *GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant(
   std::string globalName = llvm::formatv("{0}_kernel_name", name);
   return LLVM::createGlobalString(
       loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()),
-      llvmDialect);
+      LLVM::Linkage::Internal, llvmDialect);
 }
 
 // Emits LLVM IR to launch a kernel function. Expects the module that contains
@@ -368,7 +368,8 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
   SmallString<128> nameBuffer(*kernelModule.getName());
   nameBuffer.append(kCubinStorageSuffix);
   Value *data = LLVM::createGlobalString(
-      loc, builder, nameBuffer.str(), cubinAttr.getValue(), getLLVMDialect());
+      loc, builder, nameBuffer.str(), cubinAttr.getValue(),
+      LLVM::Linkage::Internal, getLLVMDialect());
 
   // Emit the load module call to load the module data. Error checking is done
   // in the called helper function.
diff --git a/third_party/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/third_party/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index f56508dfeba..54dd18e7492 100644
--- a/third_party/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/third_party/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -387,8 +387,8 @@ private:
         builder.getNamedAttr("addr_space", builder.getI32IntegerAttr(3));
     auto globalOp = builder.create(
         loc, arrayType.cast(),
-        /*isConstant=*/false, name, /*value=*/Attribute(),
-        llvm::makeArrayRef(addrSpace));
+        /*isConstant=*/false, LLVM::Linkage::Internal, name,
+        /*value=*/Attribute(), llvm::makeArrayRef(addrSpace));
 
     return rewriter.create(loc, globalOp);
   }
diff --git a/third_party/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/third_party/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 66a9bc0ae9f..a8c676ff696 100644
--- a/third_party/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/third_party/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -863,8 +863,8 @@ static ParseResult parseConstantOp(OpAsmParser &parser,
 //===----------------------------------------------------------------------===//
 
 void GlobalOp::build(Builder *builder, OperationState &result, LLVMType type,
-                     bool isConstant, StringRef name, Attribute value,
-                     ArrayRef attrs) {
+                     bool isConstant, Linkage linkage, StringRef name,
+                     Attribute value, ArrayRef attrs) {
   result.addAttribute(SymbolTable::getSymbolAttrName(),
                       builder->getStringAttr(name));
   result.addAttribute("type", TypeAttr::get(type));
@@ -872,12 +872,56 @@ void GlobalOp::build(Builder *builder, OperationState &result, LLVMType type,
     result.addAttribute("constant", builder->getUnitAttr());
   if (value)
     result.addAttribute("value", value);
+  result.addAttribute(
+      "linkage", builder->getI64IntegerAttr(static_cast(linkage)));
   result.attributes.append(attrs.begin(), attrs.end());
   result.addRegion();
 }
 
+// Prints the keyword for the linkage type using the printer.
+static void printLinkage(OpAsmPrinter &p, LLVM::Linkage linkage) {
+  switch (linkage) {
+  case LLVM::Linkage::Private:
+    p << "private";
+    return;
+  case LLVM::Linkage::Internal:
+    p << "internal";
+    return;
+  case LLVM::Linkage::AvailableExternally:
+    p << "available_externally";
+    return;
+  case LLVM::Linkage::Linkonce:
+    p << "linkonce";
+    return;
+  case LLVM::Linkage::Weak:
+    p << "weak";
+    return;
+  case LLVM::Linkage::Common:
+    p << "common";
+    return;
+  case LLVM::Linkage::Appending:
+    p << "appending";
+    return;
+  case LLVM::Linkage::ExternWeak:
+    p << "extern_weak";
+    return;
+  case LLVM::Linkage::LinkonceODR:
+    p << "linkonce_odr";
+    return;
+  case LLVM::Linkage::WeakODR:
+    p << "weak_odr";
+    return;
+  case LLVM::Linkage::External:
+    p << "external";
+    return;
+  }
+  llvm_unreachable("unknown linkage type");
+}
+
 static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) {
   p << op.getOperationName() << ' ';
+  printLinkage(p, op.linkage());
+  p << ' ';
   if (op.constant())
     p << "constant ";
   p.printSymbolName(op.sym_name());
@@ -885,8 +929,9 @@ static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) {
   if (auto value = op.getValueOrNull())
     p.printAttribute(value);
   p << ')';
-  p.printOptionalAttrDict(op.getAttrs(), {SymbolTable::getSymbolAttrName(),
-                                          "type", "constant", "value"});
+  p.printOptionalAttrDict(op.getAttrs(),
+                          {SymbolTable::getSymbolAttrName(), "type", "constant",
+                           "value", "linkage"});
 
   // Print the trailing type unless it's a string global.
   if (op.getValueOrNull().dyn_cast_or_null())
@@ -899,12 +944,45 @@ static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) {
     p.printRegion(initializer, /*printEntryBlockArgs=*/false);
 }
 
-//  ::= `llvm.mlir.global` `constant`? `@` identifier
-//                 `(` attribute? `)` attribute-list? (`:` type)? region?
+// Parses one of the keywords provided in the list `keywords` and returns the
+// position of the parsed keyword in the list. If none of the keywords from the
+// list is parsed, returns -1.
+static int parseOptionalKeywordAlternative(OpAsmParser &parser,
+                                           ArrayRef keywords) {
+  for (auto en : llvm::enumerate(keywords)) {
+    if (succeeded(parser.parseOptionalKeyword(en.value())))
+      return en.index();
+  }
+  return -1;
+}
+
+// Parses one of the linkage keywords and, if succeeded, appends the "linkage"
+// integer attribute with the corresponding value to `result`.
+//
+// linkage ::= `private` | `internal` | `available_externally` | `linkonce`
+//           | `weak` | `common` | `appending` | `extern_weak`
+//           | `linkonce_odr` | `weak_odr` | `external
+static ParseResult parseOptionalLinkageKeyword(OpAsmParser &parser,
+                                               OperationState &result) {
+  int index = parseOptionalKeywordAlternative(
+      parser, {"private", "internal", "available_externally", "linkonce",
+               "weak", "common", "appending", "extern_weak", "linkonce_odr",
+               "weak_odr", "external"});
+  if (index == -1)
+    return failure();
+  result.addAttribute("linkage", parser.getBuilder().getI64IntegerAttr(index));
+  return success();
+}
+
+// operation ::= `llvm.mlir.global` linkage `constant`? `@` identifier
+//               `(` attribute? `)` attribute-list? (`:` type)? region?
 //
 // The type can be omitted for string attributes, in which case it will be
 // inferred from the value of the string as [strlen(value) x i8].
 static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
+  if (failed(parseOptionalLinkageKeyword(parser, result)))
+    return parser.emitError(parser.getCurrentLocation(), "expected linkage");
+
   if (succeeded(parser.parseOptionalKeyword("constant")))
     result.addAttribute("constant", parser.getBuilder().getUnitAttr());
 
@@ -1489,6 +1567,7 @@ LLVMType LLVMType::getVoidTy(LLVMDialect *dialect) {
 
 Value *mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
                                       StringRef name, StringRef value,
+                                      LLVM::Linkage linkage,
                                       LLVM::LLVMDialect *llvmDialect) {
   assert(builder.getInsertionBlock() &&
          builder.getInsertionBlock()->getParentOp() &&
@@ -1502,7 +1581,8 @@ Value *mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
   auto type = LLVM::LLVMType::getArrayTy(LLVM::LLVMType::getInt8Ty(llvmDialect),
                                          value.size());
   auto global = moduleBuilder.create(
-      loc, type, /*isConstant=*/true, name, builder.getStringAttr(value));
+      loc, type, /*isConstant=*/true, linkage, name,
+      builder.getStringAttr(value));
 
   // Get the pointer to the first character in the global string.
   Value *globalPtr = builder.create(loc, global);
diff --git a/third_party/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/third_party/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
index fd4e4134d8b..6cf975bcce2 100644
--- a/third_party/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/third_party/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -215,6 +215,37 @@ Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
   return Attribute();
 }
 
+/// Converts LLVM global variable linkage type into the LLVM dialect predicate.
+static LLVM::Linkage
+processLinkage(llvm::GlobalVariable::LinkageTypes linkage) {
+  switch (linkage) {
+  case llvm::GlobalValue::PrivateLinkage:
+    return LLVM::Linkage::Private;
+  case llvm::GlobalValue::InternalLinkage:
+    return LLVM::Linkage::Internal;
+  case llvm::GlobalValue::AvailableExternallyLinkage:
+    return LLVM::Linkage::AvailableExternally;
+  case llvm::GlobalValue::LinkOnceAnyLinkage:
+    return LLVM::Linkage::Linkonce;
+  case llvm::GlobalValue::WeakAnyLinkage:
+    return LLVM::Linkage::Weak;
+  case llvm::GlobalValue::CommonLinkage:
+    return LLVM::Linkage::Common;
+  case llvm::GlobalValue::AppendingLinkage:
+    return LLVM::Linkage::Appending;
+  case llvm::GlobalValue::ExternalWeakLinkage:
+    return LLVM::Linkage::ExternWeak;
+  case llvm::GlobalValue::LinkOnceODRLinkage:
+    return LLVM::Linkage::LinkonceODR;
+  case llvm::GlobalValue::WeakODRLinkage:
+    return LLVM::Linkage::WeakODR;
+  case llvm::GlobalValue::ExternalLinkage:
+    return LLVM::Linkage::External;
+  }
+
+  llvm_unreachable("unhandled linkage type");
+}
+
 GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) {
   auto it = globals.find(GV);
   if (it != globals.end())
@@ -224,9 +255,10 @@ GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) {
   Attribute valueAttr;
   if (GV->hasInitializer())
     valueAttr = getConstantAsAttr(GV->getInitializer());
-  GlobalOp op = b.create(UnknownLoc::get(context),
-                                   processType(GV->getValueType()),
-                                   GV->isConstant(), GV->getName(), valueAttr);
+  GlobalOp op = b.create(
+      UnknownLoc::get(context), processType(GV->getValueType()),
+      GV->isConstant(), processLinkage(GV->getLinkage()), GV->getName(),
+      valueAttr);
   if (GV->hasInitializer() && !valueAttr) {
     Region &r = op.getInitializerRegion();
     currentEntryBlock = b.createBlock(&r);
diff --git a/third_party/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/third_party/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 7f3ce5a738f..f985fed3991 100644
--- a/third_party/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/third_party/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -279,6 +279,35 @@ LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) {
   return success();
 }
 
+// Convert the LLVM dialect linkage type to LLVM IR linkage type.
+llvm::GlobalVariable::LinkageTypes convertLinkageType(LLVM::Linkage linkage) {
+  switch (linkage) {
+  case LLVM::Linkage::Private:
+    return llvm::GlobalValue::PrivateLinkage;
+  case LLVM::Linkage::Internal:
+    return llvm::GlobalValue::InternalLinkage;
+  case LLVM::Linkage::AvailableExternally:
+    return llvm::GlobalValue::AvailableExternallyLinkage;
+  case LLVM::Linkage::Linkonce:
+    return llvm::GlobalValue::LinkOnceAnyLinkage;
+  case LLVM::Linkage::Weak:
+    return llvm::GlobalValue::WeakAnyLinkage;
+  case LLVM::Linkage::Common:
+    return llvm::GlobalValue::CommonLinkage;
+  case LLVM::Linkage::Appending:
+    return llvm::GlobalValue::AppendingLinkage;
+  case LLVM::Linkage::ExternWeak:
+    return llvm::GlobalValue::ExternalWeakLinkage;
+  case LLVM::Linkage::LinkonceODR:
+    return llvm::GlobalValue::LinkOnceODRLinkage;
+  case LLVM::Linkage::WeakODR:
+    return llvm::GlobalValue::WeakODRLinkage;
+  case LLVM::Linkage::External:
+    return llvm::GlobalValue::ExternalLinkage;
+  }
+  llvm_unreachable("unknown linkage type");
+}
+
 // Create named global variables that correspond to llvm.mlir.global
 // definitions.
 void ModuleTranslation::convertGlobals() {
@@ -308,11 +337,15 @@ void ModuleTranslation::convertGlobals() {
       cst = cast(valueMapping.lookup(ret.getOperand(0)));
     }
 
+    auto linkage = convertLinkageType(op.linkage());
+    bool anyExternalLinkage =
+        (linkage == llvm::GlobalVariable::ExternalLinkage ||
+         linkage == llvm::GlobalVariable::ExternalWeakLinkage);
     auto addrSpace = op.addr_space().getLimitedValue();
     auto *var = new llvm::GlobalVariable(
-        *llvmModule, type, op.constant(), llvm::GlobalValue::InternalLinkage,
-        cst, op.sym_name(), /*InsertBefore=*/nullptr,
-        llvm::GlobalValue::NotThreadLocal, addrSpace);
+        *llvmModule, type, op.constant(), linkage,
+        anyExternalLinkage ? nullptr : cst, op.sym_name(),
+        /*InsertBefore=*/nullptr, llvm::GlobalValue::NotThreadLocal, addrSpace);
 
     globalsMapping.try_emplace(op, var);
   }

From 4cd9b16bdc5bf3ec906dce0cff04437b1c72d017 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Mon, 2 Dec 2019 03:38:34 -0800
Subject: [PATCH 135/279] Remove distribution strategy device map code.

PiperOrigin-RevId: 283310504
Change-Id: Ie7af08bdc52c660d61ab527d72f616c088ad1480
---
 .../collective_all_reduce_strategy.py         |   5 +-
 .../python/distribute/cross_device_ops.py     |  89 +--
 .../distribute/cross_device_ops_test.py       |  17 +-
 .../distribute/cross_device_utils_test.py     |   3 +-
 .../python/distribute/distribute_lib_test.py  |   5 +-
 tensorflow/python/distribute/input_lib.py     |  45 +-
 .../python/distribute/input_lib_test.py       |  11 +-
 .../distribute/mirrored_function_strategy.py  |   8 +-
 .../python/distribute/mirrored_strategy.py    |  87 +--
 .../distribute/mirrored_strategy_test.py      |  33 +-
 .../distribute/mirrored_variable_test.py      |  53 +-
 .../python/distribute/one_device_strategy.py  |   4 +-
 .../distribute/parameter_server_strategy.py   |  28 +-
 tensorflow/python/distribute/tpu_strategy.py  |  35 +-
 tensorflow/python/distribute/values.py        | 525 +++++++++++++++---
 tensorflow/python/distribute/values_test.py   | 427 +++++++++-----
 .../keras/distribute/keras_utils_test.py      |  10 +-
 tensorflow/python/module/module_test.py       |   5 +-
 .../python/ops/stateful_random_ops_test.py    |   4 +-
 tensorflow/python/tpu/tpu.py                  |  29 +-
 tensorflow/python/training/optimizer.py       |   2 +-
 21 files changed, 1018 insertions(+), 407 deletions(-)

diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py
index 89d13c0777f..507e7779cfe 100644
--- a/tensorflow/python/distribute/collective_all_reduce_strategy.py
+++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py
@@ -209,7 +209,6 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
         local_devices = tuple("/device:GPU:%d" % i for i in range(num_gpus))
       else:
         local_devices = ("/device:CPU:0",)
-
     self._worker_device = device_util.canonicalize("/device:CPU:0")
     self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)
 
@@ -328,7 +327,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
     super(CollectiveAllReduceExtended, self)._initialize_single_worker(
         local_devices)
     self._input_workers = input_lib.InputWorkers(
-        [(self._worker_device, self.worker_devices)])
+        self._device_map, [(self._worker_device, self.worker_devices)])
 
     # Add a default device so that ops without specified devices will not end up
     # on other workers.
@@ -524,7 +523,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
       # replicas in which case `value` would be a single value or value could
       # be 0.
       return cross_device_ops_lib.reduce_non_distributed_value(
-          reduce_op, value, destinations, len(self.worker_devices))
+          reduce_op, self._device_map, value, destinations)
     return self._get_cross_device_ops().reduce(
         reduce_op, value, destinations=destinations)
 
diff --git a/tensorflow/python/distribute/cross_device_ops.py b/tensorflow/python/distribute/cross_device_ops.py
index 07aab81587a..9fc49df0ead 100644
--- a/tensorflow/python/distribute/cross_device_ops.py
+++ b/tensorflow/python/distribute/cross_device_ops.py
@@ -65,7 +65,10 @@ def validate_destinations(destinations):
           ops.Tensor,
           value_lib.AggregatingVariable,
           six.string_types,
-          value_lib.TPUMirroredVariable)):
+          value_lib.TPUMirroredVariable,
+          # LogicalDeviceSpec is only used internally, e.g. as a
+          # broadcast destination, never supplied by a user.
+          value_lib.LogicalDeviceSpec)):
     raise ValueError("destinations must be one of a `DistributedValues` object,"
                      " a tf.Variable object, or a device string.")
 
@@ -73,8 +76,7 @@ def validate_destinations(destinations):
     raise ValueError("destinations can not be empty")
 
 
-def reduce_non_distributed_value(
-    reduce_op, value, destinations, num_replicas_in_graph):
+def reduce_non_distributed_value(reduce_op, device_map, value, destinations):
   """Reduce a non-DistributedValue `value` to `destinations`."""
   if isinstance(value, value_lib.DistributedValues):
     raise ValueError("You are passing a `DistributedValue` to "
@@ -90,16 +92,15 @@ def reduce_non_distributed_value(
   # that value should be on all destinations.
   if reduce_op == reduce_util.ReduceOp.MEAN:
     return value
-  elif num_replicas_in_graph != 1:
-    # We do not support a reduce op of SUM if the value is the same across
-    # all replicas. We call this as part of assign functions for
-    # MirroredVariables and summing up identical values across replicas is not
-    # clearly defined.
+
+  validate_destinations(destinations)
+  # We do not support a reduce op of SUM if the value is the same across
+  # all replicas. We call this as part of assign functions for MirroredVariables
+  # and summing up identical values across replicas is not clearly defined.
+  if device_map.num_replicas_in_graph != 1:
     raise ValueError("A non-DistributedValues value %s cannot be reduced with "
                      "the given reduce op %s." % (value, reduce_op))
-  else:
-    validate_destinations(destinations)
-    return simple_broadcast(value, destinations)
+  return simple_broadcast(value, destinations)
 
 
 def _make_tensor_into_per_replica(input_tensor):
@@ -110,12 +111,16 @@ def _make_tensor_into_per_replica(input_tensor):
                      % (input_tensor,))
   if isinstance(input_tensor, value_lib.PerReplica):
     return input_tensor
-  elif hasattr(input_tensor, "device"):
-    return value_lib.PerReplica((input_tensor,))
-  else:
+
+  try:
+    device = input_tensor.device
+  except AttributeError:
     raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object "
                      "because it doesn't have device set.")
 
+  device_map = value_lib.SingleDeviceMap(device)
+  return value_lib.PerReplica(device_map, (input_tensor,))
+
 
 def _normalize_value_destination_pairs(value_destination_pairs):
   """Converts each tensor into a PerReplica object in the input list."""
@@ -156,11 +161,25 @@ def _validate_value_destination_pairs(value_destination_pairs):
 def get_devices_from(destinations):
   if isinstance(destinations, value_lib.DistributedValues):
     return destinations.devices
+  elif isinstance(destinations, value_lib.LogicalDeviceSpec):
+    return destinations.device_map.logical_to_actual_devices(
+        destinations.logical_device)
   elif isinstance(destinations, six.string_types):
     return (device_util.resolve(destinations),)
   return (device_util.resolve(destinations.device),)
 
 
+def get_device_map_from(destinations):
+  if isinstance(destinations, (value_lib.DistributedValues,
+                               value_lib.LogicalDeviceSpec)):
+    return destinations.device_map, destinations.logical_device
+  if isinstance(destinations, six.string_types):
+    device = device_util.resolve(destinations)
+  else:
+    device = destinations.device
+  return value_lib.SingleDeviceMap(device), 0
+
+
 def _devices_match(left, right):
   return set(get_devices_from(left)) == set(get_devices_from(right))
 
@@ -176,7 +195,8 @@ def _all_devices_match(value_destination_pairs):
 
 def simple_broadcast(value, destinations, always_mirrored=False):
   """Broadcast `value` to `destinations` using simple copies."""
-  devices = get_devices_from(destinations)
+  device_map, logical_device = get_device_map_from(destinations)
+  devices = device_map.logical_to_actual_devices(logical_device)
   if len(devices) == 1 and not always_mirrored:
     return cross_device_utils.copy_tensor_or_indexed_slices_to_device(
         value, devices[0])
@@ -184,8 +204,10 @@ def simple_broadcast(value, destinations, always_mirrored=False):
     value_updates = []
     for d in devices:
       value_updates.append(
-          cross_device_utils.copy_tensor_or_indexed_slices_to_device(value, d))
-    return value_lib.regroup(value_updates, wrap_class=value_lib.Mirrored)
+          cross_device_utils.copy_tensor_or_indexed_slices_to_device(
+              value, d))
+    return value_lib.regroup(
+        device_map, value_updates, wrap_class=value_lib.Mirrored)
 
 
 def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
@@ -252,6 +274,7 @@ class CrossDeviceOps(object):
         per_replica_value.values) == 1 and _devices_match(
             per_replica_value, destinations):
       return value_lib.regroup(
+          per_replica_value.device_map,
           per_replica_value.values,
           wrap_class=value_lib.Mirrored)
 
@@ -296,7 +319,8 @@ class CrossDeviceOps(object):
         value_destination_pairs) and len(
             value_destination_pairs[0][0].values) == 1:
       return [
-          value_lib.regroup(v.values, wrap_class=value_lib.Mirrored)
+          value_lib.regroup(
+              v.device_map, v.values, wrap_class=value_lib.Mirrored)
           for v, _ in value_destination_pairs
       ]
 
@@ -474,7 +498,8 @@ def _ungroup_and_make_mirrored(grouped_reduced,
   Returns:
     a list of Mirrored objects.
   """
-  num_replicas = len(get_devices_from(destinations)) * num_between_graph_workers
+  device_map, _ = get_device_map_from(destinations)
+  num_replicas = device_map.num_replicas_in_graph * num_between_graph_workers
   index = [[] for _ in range(len(grouped_reduced[0]))]
   for per_replica_reduced in grouped_reduced:
     for i, (v, _) in enumerate(per_replica_reduced):
@@ -483,7 +508,10 @@ def _ungroup_and_make_mirrored(grouped_reduced,
           index[i].append(v / num_replicas)
       else:
         index[i].append(v)
-  return [value_lib.regroup(v, wrap_class=value_lib.Mirrored) for v in index]
+  return [
+      value_lib.regroup(device_map, v, wrap_class=value_lib.Mirrored)
+      for v in index
+  ]
 
 
 class _ConcatAndSplitPacker(object):
@@ -1008,33 +1036,32 @@ class CollectiveAllReduce(CrossDeviceOps):
 
   def reduce_implementation(self, reduce_op, per_replica_value, destinations):
     all_reduced = self._batch_all_reduce(reduce_op, [per_replica_value])[0]
-    devices = get_devices_from(destinations)
+    device_map, logical_device = get_device_map_from(destinations)
+    devices = device_map.logical_to_actual_devices(logical_device)
 
     if (isinstance(all_reduced, value_lib.Mirrored) and
-        (all_reduced.devices == devices)):
+        all_reduced.device_map is device_map and
+        all_reduced.logical_device == logical_device):
       return all_reduced
 
     # Convert `all_reduced` to a `Mirrored` object, as a simple and uniform
     # utility to access component for a particular device.
     if not isinstance(all_reduced, value_lib.Mirrored):
-      all_reduced = value_lib.Mirrored([all_reduced])
+      all_reduced = value_lib.Mirrored(
+          value_lib.SingleDeviceMap(all_reduced.device), [all_reduced])
 
-    # If we got this far, the destination devices do not match the all-reduce
-    # devices, so we must map from one to the other.
     index = []
-    # We must add these control dependencies, otherwise we can get deadlock.
     with ops.control_dependencies(all_reduced.values):
       for d in devices:
         with ops.device(d):
-          for v in all_reduced.values:
-            if v.device == d:
-              index.append(array_ops.identity(v))
-              break
+          if d in all_reduced.devices:
+            index.append(array_ops.identity(all_reduced.get(d)))
           else:
             # TODO(josh11b): Once we add support for model parallelism, get the
             # copy from the corresponding replica instead of the primary.
             index.append(array_ops.identity(all_reduced.primary))
-    return value_lib.regroup(index, wrap_class=value_lib.Mirrored)
+
+    return value_lib.regroup(device_map, index, wrap_class=value_lib.Mirrored)
 
   def batch_reduce_implementation(self, reduce_op, value_destination_pairs):
     all_devices_match = _all_devices_match(value_destination_pairs)
diff --git a/tensorflow/python/distribute/cross_device_ops_test.py b/tensorflow/python/distribute/cross_device_ops_test.py
index 7af7c48e57d..0ec049ad4c1 100644
--- a/tensorflow/python/distribute/cross_device_ops_test.py
+++ b/tensorflow/python/distribute/cross_device_ops_test.py
@@ -66,7 +66,7 @@ def _make_per_replica(values, devices, regroup=False):
     with ops.device(d):
       placed_v = array_ops.identity(v)
     index.append(placed_v)
-  return value_lib.regroup(index)
+  return value_lib.regroup(value_lib.ReplicaDeviceMap(devices), index)
 
 
 # pylint: disable=g-doc-args,g-doc-return-or-yield
@@ -82,6 +82,7 @@ def _fake_mirrored(value, devices):
     with ops.device(d):
       values.append(array_ops.identity(value))
   return value_lib.regroup(
+      value_lib.ReplicaDeviceMap(devices),
       values,
       wrap_class=value_lib.Mirrored)
 
@@ -99,6 +100,7 @@ def _make_mirrored_indexed_slices(devices, values, indices, dense_shape):
   values = [_make_indexed_slices(values, indices, dense_shape, d)
             for d in devices]
   return value_lib.regroup(
+      value_lib.ReplicaDeviceMap(devices),
       values,
       wrap_class=value_lib.Mirrored)
 
@@ -125,7 +127,8 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase):
     else:
       if isinstance(left, value_lib.DistributedValues):
         self.assertEqual(set(left.devices), set(right.devices))
-        self._assert_values_equal(left.values, right.values)
+        self._assert_values_equal([left.get(d) for d in sorted(left.devices)],
+                                  [right.get(d) for d in sorted(right.devices)])
       else:
         self.assertEqual(
             device_util.resolve(left.device), device_util.resolve(right.device))
@@ -214,7 +217,8 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase):
     t0 = _make_indexed_slices([[1., 2.]], [1], dense_shape, devices[0])
     t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], dense_shape,
                               devices[1])
-    per_replica = value_lib.PerReplica((t0, t1))
+    per_replica = value_lib.PerReplica(
+        value_lib.ReplicaDeviceMap(devices), (t0, t1))
 
     if batch_reduce:
       result = cross_device_ops_instance.batch_reduce(
@@ -335,6 +339,7 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase):
           cross_device_ops_lib.choose_the_best(devices),
           cross_device_ops_lib.ReductionToOneDevice)
 
+
   @combinations.generate(combinations.combine(
       mode=["graph", "eager"],
       required_gpus=1))
@@ -342,7 +347,8 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase):
     devices = ["/cpu:0", "/gpu:0"]
     t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0])
     t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1])
-    per_replica = value_lib.PerReplica((t0, t1))
+    per_replica = value_lib.PerReplica(
+        value_lib.ReplicaDeviceMap(devices), (t0, t1))
     result = cross_device_ops_lib._simple_reduce(
         per_replica, devices[0], math_ops.add_n, reduce_util.ReduceOp.SUM)
 
@@ -642,7 +648,8 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
       indexed_slices.append(
           _make_indexed_slices(values[idx], indices[idx], dense_shape, d))
     if as_per_replica:
-      per_replica = value_lib.PerReplica(indexed_slices)
+      per_replica = value_lib.PerReplica(
+          value_lib.ReplicaDeviceMap(devices), indexed_slices)
       return per_replica
     else:
       return indexed_slices
diff --git a/tensorflow/python/distribute/cross_device_utils_test.py b/tensorflow/python/distribute/cross_device_utils_test.py
index 217883ea21b..16caad7615a 100644
--- a/tensorflow/python/distribute/cross_device_utils_test.py
+++ b/tensorflow/python/distribute/cross_device_utils_test.py
@@ -103,7 +103,8 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
         constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
     t1 = math_ops._as_indexed_slices(
         constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
-    per_replica = value_lib.PerReplica((t0, t1))
+    device_map = value_lib.ReplicaDeviceMap(("/gpu:0", "/cpu:0"))
+    per_replica = value_lib.PerReplica(device_map, (t0, t1))
     self.assertTrue(cross_device_utils.contains_indexed_slices(per_replica))
 
   @combinations.generate(combinations.combine(
diff --git a/tensorflow/python/distribute/distribute_lib_test.py b/tensorflow/python/distribute/distribute_lib_test.py
index fe97d37f417..fb8116d4ab2 100644
--- a/tensorflow/python/distribute/distribute_lib_test.py
+++ b/tensorflow/python/distribute/distribute_lib_test.py
@@ -27,6 +27,7 @@ from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
 from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import reduce_util
+from tensorflow.python.distribute import values
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -66,8 +67,10 @@ class _TestExtended(distribute_lib.StrategyExtendedV1):
 
   def __init__(self, distribute):
     super(_TestExtended, self).__init__(distribute)
+    device_map = values.ReplicaDeviceMap(["/device:CPU:0"])
     worker_device_pairs = [("", ["/device:CPU:0"])]
-    self._input_workers = input_lib.InputWorkers(worker_device_pairs)
+    self._input_workers = input_lib.InputWorkers(device_map,
+                                                 worker_device_pairs)
 
   def _call_for_each_replica(self, fn, args, kwargs):
     with _TestReplicaContext(
diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py
index 2b92b7f7c22..f1f9a0e872d 100644
--- a/tensorflow/python/distribute/input_lib.py
+++ b/tensorflow/python/distribute/input_lib.py
@@ -130,16 +130,40 @@ def get_distributed_datasets_from_function(dataset_fn,
 class InputWorkers(object):
   """A 1-to-many mapping from input worker devices to compute devices."""
 
-  def __init__(self, worker_device_pairs):
+  def __init__(self, device_map, worker_device_pairs=None, logical_device=0):
     """Initialize an `InputWorkers` object.
 
     Args:
+      device_map: A `DeviceMap` with the computation devices fed by the
+        input workers.
       worker_device_pairs: A sequence of pairs:
         `(input device, a tuple of compute devices fed by that input device)`.
+      logical_device: The logical device of `device_map` to feed.
     """
+    self._device_map = device_map
+    self._logical_device = logical_device
+    if worker_device_pairs is None:
+      devices = device_map.logical_to_actual_devices(logical_device)
+      worker_device_pairs = ((
+          device_util.canonicalize("/device:CPU:0", devices[0]),
+          devices),)
     self._input_worker_devices = tuple(d for d, _ in worker_device_pairs)
     self._fed_devices = tuple(tuple(device_util.canonicalize(d) for d in f)
                               for _, f in worker_device_pairs)
+    flattened = tuple(d for l in self._fed_devices for d in l)
+    assert (flattened ==
+            device_map.logical_to_actual_devices(logical_device)), (
+                "flattened: %s logical device %d: %s" %
+                (flattened, logical_device,
+                 device_map.logical_to_actual_devices(logical_device)))
+
+  @property
+  def device_map(self):
+    return self._device_map
+
+  @property
+  def logical_device(self):
+    return self._logical_device
 
   @property
   def num_workers(self):
@@ -157,7 +181,8 @@ class InputWorkers(object):
     debug_repr = ",\n".join("  %d %s: %s" %
                             (i, devices[i], self._fed_devices[i])
                             for i in range(len(devices)))
-    return "%s:{\n%s}" % (self.__class__.__name__, debug_repr)
+    return "%s:{\n%s\n  device_map: %s}" % (
+        self.__class__.__name__, debug_repr, self._device_map)
 
 
 def _get_next_as_optional(iterator, strategy, name=None):
@@ -188,9 +213,12 @@ def _get_next_as_optional(iterator, strategy, name=None):
   # TODO(b/131423105): we should be able to short-cut the all-reduce in some
   # cases.
   if getattr(strategy.extended, "_support_per_replica_values", True):
-    # Slight hack: `reduce` expects a `PerReplica`, so we pass it one, even
-    # though it doesn't actually have a value per replica.
-    worker_has_values = values.PerReplica(worker_has_values)
+    worker_has_values = values.PerReplica(
+        values.WorkerDeviceMap(
+            worker_devices,
+            num_replicas_per_worker=len(
+                strategy.extended._input_workers._input_worker_devices)),  # pylint: disable=protected-access
+        worker_has_values)
     global_has_value = strategy.reduce(
         reduce_util.ReduceOp.SUM, worker_has_values, axis=None)
   else:
@@ -264,7 +292,7 @@ class DistributedIterator(object):
           # Make `replicas` a flat list of values across all replicas.
           replicas.extend(
               self._iterators[i].get_next_as_list_static_shapes(new_name))
-      return values.regroup(replicas)
+      return values.regroup(self._input_workers.device_map, replicas)
 
     out_of_range_replicas = []
     def out_of_range_fn(worker_index, device):
@@ -324,7 +352,7 @@ class DistributedIterator(object):
               dense_shape=dense_shape)
     replicas = nest.pack_sequence_as(replicas, flattened_replicas)
 
-    return values.regroup(replicas)
+    return values.regroup(self._input_workers.device_map, replicas)
 
   # We need a private initializer method for re-initializing multidevice
   # iterators when used with Keras training loops. If we don't reinitialize the
@@ -431,7 +459,8 @@ class _IterableInput(object):
       else:
         raise ValueError("Dataset iteration within a tf.function is"
                          " not supported for multiple workers.")
-      state = reduce_fn(state, values.regroup(data))
+      per_replica_data = values.regroup(self._input_workers.device_map, data)
+      state = reduce_fn(state, per_replica_data)
       has_data, data = _get_next_as_optional(iterator, self._strategy)
       return has_data, data, state
 
diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py
index 433d18d36cb..96363053219 100644
--- a/tensorflow/python/distribute/input_lib_test.py
+++ b/tensorflow/python/distribute/input_lib_test.py
@@ -138,7 +138,8 @@ class DistributedIteratorTestBase(test.TestCase):
       self.skipTest("unsupported test combination.")
 
     devices = nest.flatten([ds for _, ds in worker_device_pairs])
-    input_workers = input_lib.InputWorkers(worker_device_pairs)
+    device_map = values.ReplicaDeviceMap(devices)
+    input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)
 
     if api_type == "wrap_into_iterator":
       iterator = self._wrap_iterator(
@@ -235,7 +236,9 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
     worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
     dataset_fn = lambda _: dataset_ops.DatasetV1.range(10)
 
-    input_workers = input_lib.InputWorkers(worker_device_pairs)
+    devices = nest.flatten([ds for _, ds in worker_device_pairs])
+    device_map = values.ReplicaDeviceMap(devices)
+    input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)
 
     dist_dataset = input_lib.get_distributed_dataset(
         dataset_fn(distribute_lib.InputContext()), input_workers, distribution)
@@ -257,7 +260,9 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
           ]))
   def testDatasetV2IterError(self, distribution):
     worker_device_pairs = [("", ["/device:CPU:0"])]
-    input_workers = input_lib.InputWorkers(worker_device_pairs)
+    devices = nest.flatten([ds for _, ds in worker_device_pairs])
+    device_map = values.ReplicaDeviceMap(devices)
+    input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)
     dataset_fn = lambda _: dataset_ops.DatasetV2.range(10).batch(2)
 
     dist_dataset = input_lib.get_distributed_dataset(
diff --git a/tensorflow/python/distribute/mirrored_function_strategy.py b/tensorflow/python/distribute/mirrored_function_strategy.py
index aa9ecfa1fc4..aa81aaabfe0 100644
--- a/tensorflow/python/distribute/mirrored_function_strategy.py
+++ b/tensorflow/python/distribute/mirrored_function_strategy.py
@@ -91,7 +91,8 @@ class MirroredFunctionExtended(distribute_lib.StrategyExtendedV1):
     device_tuple = tuple(device_util.resolve(d) for d in devices)
     assert len(set(device_tuple)) == len(device_tuple), (
         "No duplicates allowed in `devices` argument: %s" % (devices,))
-    self._devices = device_tuple
+    self._device_map = values.ReplicaDeviceMap(device_tuple)
+
     self._retrace_functions_for_each_device = False
 
   def _call_for_each_replica(self, fn, args, kwargs):
@@ -115,7 +116,8 @@ class MirroredFunctionExtended(distribute_lib.StrategyExtendedV1):
 
     try:
       with MirroredFunctionReplicaContext(self._container_strategy()):
-        for index, device in enumerate(self._devices):
+        for index, device in enumerate(
+            self._device_map.logical_to_actual_devices(0)):
           _replica_index.current = index
           with ops.device(device):
             if context.executing_eagerly():
@@ -132,7 +134,7 @@ class MirroredFunctionExtended(distribute_lib.StrategyExtendedV1):
       _replica_index.graph_outside_run = None
       _replica_index.current = None
 
-    return values.regroup(return_values)
+    return values.regroup(self._device_map, return_values)
 
   def _local_results(self, val):
     if isinstance(val, values.DistributedValues):
diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py
index b45c52e9ad6..0fb8ae0aafb 100644
--- a/tensorflow/python/distribute/mirrored_strategy.py
+++ b/tensorflow/python/distribute/mirrored_strategy.py
@@ -89,12 +89,12 @@ class _RequestedStop(Exception):  # pylint: disable=g-bad-exception-name
 
 # TODO(yuefengz): maybe create a common class for those who need to call this
 # _call_for_each_replica.
-def _call_for_each_replica(distribution, devices, fn, args, kwargs):
+def _call_for_each_replica(distribution, device_map, fn, args, kwargs):
   """Run `fn` in separate threads, once per replica/worker device.
 
   Args:
     distribution: the DistributionStrategy object.
-    devices: the devices to run `fn` on (logical device 0 for each replica).
+    device_map: the DeviceMap with the devices to run `fn` on.
     fn: function to run (will be run once per replica, each in its own thread).
     args: positional arguments for `fn`
     kwargs: keyword arguments for `fn`.
@@ -119,11 +119,11 @@ def _call_for_each_replica(distribution, devices, fn, args, kwargs):
 
   # TODO(isaprykin): Create these threads once instead of during every call.
   threads = []
-  for index in range(len(devices)):
+  for index in range(device_map.num_replicas_in_graph):
     variable_creator_fn = shared_variable_creator.make_fn(
         shared_variable_store, index)
     t = _MirroredReplicaThread(
-        distribution, coord, index, devices, variable_creator_fn, fn,
+        distribution, coord, index, device_map, variable_creator_fn, fn,
         values.select_replica(index, args),
         values.select_replica(index, kwargs))
     threads.append(t)
@@ -173,8 +173,10 @@ def _call_for_each_replica(distribution, devices, fn, args, kwargs):
             raise RuntimeError("Some replicas made a different number of "
                                "replica_context().merge_call() calls.")
           # get_replica_context().merge_call() case
-          merge_args = values.regroup(tuple(t.merge_args for t in threads))
-          merge_kwargs = values.regroup(tuple(t.merge_kwargs for t in threads))
+          merge_args = values.regroup(
+              device_map, tuple(t.merge_args for t in threads))
+          merge_kwargs = values.regroup(
+              device_map, tuple(t.merge_kwargs for t in threads))
           # We capture the name_scope of the MRT when we call merge_fn
           # to ensure that if we have opened a name scope in the MRT,
           # it will be respected when executing the merge function. We only
@@ -198,7 +200,7 @@ def _call_for_each_replica(distribution, devices, fn, args, kwargs):
       t.should_run.set()
     coord.join(threads)
 
-  return values.regroup(tuple(t.main_result for t in threads))
+  return values.regroup(device_map, tuple(t.main_result for t in threads))
 
 
 def _is_device_list_single_worker(devices):
@@ -423,9 +425,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
 
   def _initialize_single_worker(self, devices):
     """Initializes the object for single-worker training."""
-    self._devices = tuple(device_util.canonicalize(d) for d in devices)
-    self._input_workers = input_lib.InputWorkers(
-        ((device_util.canonicalize("/device:CPU:0", devices[0]), devices),))
+    self._device_map = values.ReplicaDeviceMap(devices)
+    self._input_workers = input_lib.InputWorkers(self._device_map)
     self._inferred_cross_device_ops = None if self._cross_device_ops else (
         cross_device_ops_lib.choose_the_best(devices))
     self._host_input_device = numpy_dataset.SingleDevice(
@@ -460,8 +461,9 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
     self._default_device = workers[0]
     self._host_input_device = numpy_dataset.SingleDevice(workers[0])
 
-    self._devices = tuple(devices)
-    self._input_workers = input_lib.InputWorkers(worker_devices)
+    self._device_map = values.ReplicaDeviceMap(devices)
+    self._input_workers = input_lib.InputWorkers(
+        self._device_map, worker_devices)
     self._is_multi_worker_training = True
 
     if len(workers) > 1:
@@ -506,14 +508,16 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
     """Create a mirrored variable. See `DistributionStrategy.scope`."""
     colocate_with = kwargs.pop("colocate_with", None)
     if colocate_with is None:
-      devices = self._devices
+      device_map = self._device_map
+      logical_device = 0  # TODO(josh11b): Get logical device from scope here.
     elif isinstance(colocate_with, numpy_dataset.SingleDevice):
       with ops.device(colocate_with.device):
         return next_creator(*args, **kwargs)
     else:
-      devices = colocate_with.devices
+      device_map = colocate_with.device_map
+      logical_device = colocate_with.logical_device
 
-    def _real_mirrored_creator(*args, **kwargs):  # pylint: disable=g-missing-docstring
+    def _real_mirrored_creator(devices, *args, **kwargs):  # pylint: disable=g-missing-docstring
       value_list = []
       for i, d in enumerate(devices):
         with ops.device(d):
@@ -539,8 +543,9 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
       return value_list
 
     return values.create_mirrored_variable(
-        self._container_strategy(), _real_mirrored_creator,
-        values.MirroredVariable, values.SyncOnReadVariable, *args, **kwargs)
+        self._container_strategy(), device_map, logical_device,
+        _real_mirrored_creator, values.MirroredVariable,
+        values.SyncOnReadVariable, *args, **kwargs)
 
   def _validate_colocate_with_variable(self, colocate_with_variable):
     values.validate_colocate_distributed_variable(colocate_with_variable, self)
@@ -641,7 +646,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
       # For outputs that have already been reduced, wrap them in a Mirrored
       # container, else in a PerReplica container.
       if reduce_op is None:
-        last_step_tensor_outputs_dict[name] = values.regroup(output)
+        last_step_tensor_outputs_dict[name] = values.regroup(self._device_map,
+                                                             output)
       else:
         assert len(output) == 1
         last_step_tensor_outputs_dict[name] = output[0]
@@ -660,7 +666,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
     # TODO(josh11b): In eager mode, use one thread per device, or async mode.
     if not destinations:
       # TODO(josh11b): Use current logical device instead of 0 here.
-      destinations = self._devices
+      destinations = values.LogicalDeviceSpec(
+          device_map=self._device_map, logical_device=0)
     return self._get_cross_device_ops().broadcast(tensor, destinations)
 
   def _call_for_each_replica(self, fn, args, kwargs):
@@ -683,7 +690,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
                           "`experimental_run_v2` inside a tf.function to get "
                           "the best performance." %
                           self._container_strategy().__class__.__name__, 5)
-    return _call_for_each_replica(self._container_strategy(), self._devices,
+    return _call_for_each_replica(self._container_strategy(), self._device_map,
                                   fn, args, kwargs)
 
   def _configure(self,
@@ -699,7 +706,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
     if cluster_spec:
       # TODO(yuefengz): remove the following code once cluster_resolver is
       # added.
-      num_gpus_per_worker = _infer_num_gpus_per_worker(self._devices)
+      num_gpus_per_worker = _infer_num_gpus_per_worker(
+          self._device_map.all_devices)
       multi_worker_devices = _cluster_spec_to_device_list(
           cluster_spec, num_gpus_per_worker)
       self._initialize_multi_worker(multi_worker_devices)
@@ -723,7 +731,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
       # replicas in which case `value` would be a single value or value could
       # be 0.
       return cross_device_ops_lib.reduce_non_distributed_value(
-          reduce_op, value, destinations, self._num_replicas_in_sync)
+          reduce_op, self._device_map, value, destinations)
     return self._get_cross_device_ops().reduce(
         reduce_op, value, destinations=destinations)
 
@@ -735,16 +743,14 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
     # TODO(josh11b): In eager mode, use one thread per device.
     assert isinstance(var, values.DistributedVariable)
     updates = []
-    for i, v in enumerate(var.values):
+    for i, (d, v) in enumerate(zip(var.devices, var.values)):
       name = "update_%d" % i
-      with ops.device(v.device), \
-           distribute_lib.UpdateContext(i), \
-           ops.name_scope(name):
+      with ops.device(d), distribute_lib.UpdateContext(i), ops.name_scope(name):
         # If args and kwargs are not mirrored, the value is returned as is.
         updates.append(fn(v,
-                          *values.select_replica_mirrored(i, args),
-                          **values.select_replica_mirrored(i, kwargs)))
-    return values.update_regroup(self, updates, group)
+                          *values.select_device_mirrored(d, args),
+                          **values.select_device_mirrored(d, kwargs)))
+    return values.update_regroup(self, self._device_map, updates, group)
 
   def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
     assert isinstance(colocate_with, tuple)
@@ -753,9 +759,9 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
     for i, d in enumerate(colocate_with):
       name = "update_%d" % i
       with ops.device(d), distribute_lib.UpdateContext(i), ops.name_scope(name):
-        updates.append(fn(*values.select_replica_mirrored(i, args),
-                          **values.select_replica_mirrored(i, kwargs)))
-    return values.update_regroup(self, updates, group)
+        updates.append(fn(*values.select_device_mirrored(d, args),
+                          **values.select_device_mirrored(d, kwargs)))
+    return values.update_regroup(self, self._device_map, updates, group)
 
   def read_var(self, replica_local_var):
     """Read the aggregate value of a replica-local variable."""
@@ -774,19 +780,19 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
 
   @property
   def _num_replicas_in_sync(self):
-    return len(self._devices)
+    return self._device_map.num_replicas_in_graph
 
   @property
   def worker_devices(self):
-    return self._devices
+    return self._device_map.all_devices
 
   @property
   def worker_devices_by_replica(self):
-    return [[d] for d in self._devices]
+    return self._device_map.devices_by_replica
 
   @property
   def parameter_devices(self):
-    return self.worker_devices
+    return self._device_map.all_devices
 
   @property
   def experimental_between_graph(self):
@@ -807,7 +813,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
   def non_slot_devices(self, var_list):
     del var_list
     # TODO(josh11b): Should this be the last logical device instead?
-    return self._devices
+    return self._device_map.logical_to_actual_devices(0)
 
   # TODO(priyag): Delete this once all strategies use global batch size.
   @property
@@ -829,12 +835,12 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
 class _MirroredReplicaThread(threading.Thread):
   """A thread that runs() a function on a device."""
 
-  def __init__(self, dist, coord, replica_id, devices, variable_creator_fn,
+  def __init__(self, dist, coord, replica_id, device_map, variable_creator_fn,
                fn, args, kwargs):
     super(_MirroredReplicaThread, self).__init__()
     self.coord = coord
     self.distribution = dist
-    self.devices = devices
+    self.device_map = device_map
     self.replica_id = replica_id
     self.variable_creator_fn = variable_creator_fn
     # State needed to run and return the results of `fn`.
@@ -902,7 +908,8 @@ class _MirroredReplicaThread(threading.Thread):
           context.device_policy(self.context_device_policy), \
           MirroredReplicaContext(self.distribution, constant_op.constant(
               self.replica_id, dtypes.int32)), \
-          ops.device(self.devices[self.replica_id]), \
+          ops.device(self.device_map.logical_to_actual_devices(0)[
+              self.replica_id]), \
           ops.name_scope(self._name_scope), \
           variable_scope.variable_scope(
               self._var_scope, reuse=self.replica_id > 0), \
diff --git a/tensorflow/python/distribute/mirrored_strategy_test.py b/tensorflow/python/distribute/mirrored_strategy_test.py
index d2bc7ae7285..32966c904d8 100644
--- a/tensorflow/python/distribute/mirrored_strategy_test.py
+++ b/tensorflow/python/distribute/mirrored_strategy_test.py
@@ -708,21 +708,13 @@ class MirroredVariableUpdateTest(test.TestCase):
       mirrored_var_result = self.evaluate(
           mirrored_var.assign_add(6.0, read_value=True))
       self.assertEqual(7.0, mirrored_var_result)
-      self.assertEqual(7.0, self.evaluate(mirrored_var.values[0]))
-      self.assertEqual(7.0, self.evaluate(mirrored_var.values[1]))
-      self.assertEqual(
-          distribution.extended.worker_devices[0], mirrored_var.devices[0])
-      self.assertEqual(
-          distribution.extended.worker_devices[1], mirrored_var.devices[1])
+      self.assertEqual(7.0, self.evaluate(mirrored_var.get("/device:CPU:0")))
+      self.assertEqual(7.0, self.evaluate(mirrored_var.get("/device:GPU:0")))
 
       # read_value == False
       self.evaluate(mirrored_var.assign_add(2.0, read_value=False))
-      self.assertEqual(9.0, self.evaluate(mirrored_var.values[0]))
-      self.assertEqual(9.0, self.evaluate(mirrored_var.values[1]))
-      self.assertEqual(
-          distribution.extended.worker_devices[0], mirrored_var.devices[0])
-      self.assertEqual(
-          distribution.extended.worker_devices[1], mirrored_var.devices[1])
+      self.assertEqual(9.0, self.evaluate(mirrored_var.get("/device:CPU:0")))
+      self.assertEqual(9.0, self.evaluate(mirrored_var.get("/device:GPU:0")))
 
   def testAssignAddMirroredVarReplicaContext(self, distribution):
     def var_fn():
@@ -774,12 +766,8 @@ class MirroredVariableUpdateTest(test.TestCase):
       self.assertEqual(5.0, self.evaluate(mirrored_var))
       mirrored_var_result = self.evaluate(mirrored_var.assign_sub(2.0))
       self.assertEqual(3.0, mirrored_var_result)
-      self.assertEqual(3.0, self.evaluate(mirrored_var.values[0]))
-      self.assertEqual(3.0, self.evaluate(mirrored_var.values[1]))
-      self.assertEqual(
-          distribution.extended.worker_devices[0], mirrored_var.devices[0])
-      self.assertEqual(
-          distribution.extended.worker_devices[1], mirrored_var.devices[1])
+      self.assertEqual(3.0, self.evaluate(mirrored_var.get("/device:GPU:0")))
+      self.assertEqual(3.0, self.evaluate(mirrored_var.get("/device:CPU:0")))
 
   def testAssignSubMirroredVarReplicaContext(self, distribution):
     def var_fn():
@@ -990,8 +978,8 @@ class MirroredStrategyDefunTest(test.TestCase):
         per_replica_graph_functions = (
             distribution.extended.call_for_each_replica(
                 defun.get_concrete_function, args=[mock_model] + inputs))
-        for i in range(len(devices)):
-          graph_function = per_replica_graph_functions.values[i]
+        for device in devices:
+          graph_function = per_replica_graph_functions.get(device=device)
           # TODO(b/129555712): re-enable an assertion here that the two sets of
           # variables are the same.
           # self.assertEqual(set(graph_function.graph.variables),
@@ -1062,8 +1050,9 @@ class MirroredStrategyDefunTest(test.TestCase):
     def fn1(mock_model, factor):
       return mock_model(factor)
 
-    factors = values.PerReplica((5.0, 3.0))
-    expected_result = values.PerReplica((5.0 * 1.25, 3.0 * 1.25))
+    device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0"))
+    factors = values.PerReplica(device_map, (5.0, 3.0))
+    expected_result = values.PerReplica(device_map, (5.0 * 1.25, 3.0 * 1.25))
     self._call_and_check(distribution, fn1, [factors], expected_result, [fn1])
 
   def testTrain(self, distribution):
diff --git a/tensorflow/python/distribute/mirrored_variable_test.py b/tensorflow/python/distribute/mirrored_variable_test.py
index 3cc75451827..f237ee19205 100644
--- a/tensorflow/python/distribute/mirrored_variable_test.py
+++ b/tensorflow/python/distribute/mirrored_variable_test.py
@@ -87,9 +87,9 @@ class MirroredVariableCreationTest(test.TestCase):
     self.assertIsInstance(var, values.MirroredVariable)
     self.assertEqual(name, var.name)
     self.assertIs(strategy, var.distribute_strategy)
-    for i, d in enumerate(var.devices):
-      self.assertEqual(d, var.values[i].device)
-      self.assertIs(strategy, var.values[i]._distribute_strategy)  # pylint: disable=protected-access
+    for d in var.devices:
+      self.assertEqual(d, var.get(d).device)
+      self.assertIs(strategy, var.get(d)._distribute_strategy)  # pylint: disable=protected-access
 
   def testVariableInFuncGraph(self, distribution):
 
@@ -323,15 +323,16 @@ class MirroredVariableCreationTest(test.TestCase):
           aggregation=aggregation)
       return v0, v1
 
+    devices = distribution.extended.worker_devices
     with distribution.scope():
       v0, v1 = distribution.extended.call_for_each_replica(create_fn)
       self.evaluate(v0.initializer)
-      self.assertEqual(2.0, self.evaluate(v0.values[0]))
-      self.assertEqual(2.0, self.evaluate(v0.values[1]))
+      self.assertEqual(2.0, self.evaluate(v0.get(devices[0])))
+      self.assertEqual(2.0, self.evaluate(v0.get(devices[1])))
       self.assertEqual(2.0, self.evaluate(distribution.extended.read_var(v0)))
       self.evaluate(v1.initializer)
-      self.assertEqual(3.0, self.evaluate(v1.values[0]))
-      self.assertEqual(3.0, self.evaluate(v1.values[1]))
+      self.assertEqual(3.0, self.evaluate(v1.get(devices[0])))
+      self.assertEqual(3.0, self.evaluate(v1.get(devices[1])))
       self.assertEqual(3.0, self.evaluate(distribution.extended.read_var(v1)))
 
       def replica_id_plus_one():
@@ -348,20 +349,20 @@ class MirroredVariableCreationTest(test.TestCase):
 
       # Update "sync on read" variable.
       self.evaluate(distribution.group(update0a))
-      self.assertEqual(2.0 + 5.0, self.evaluate(v0.values[0]))
+      self.assertEqual(2.0 + 5.0, self.evaluate(v0.get(devices[0])))
       # Writes are not synchronized for "sync on read" variables,
       # so device[1] can end up with a different value.
-      self.assertEqual(2.0 + 2 * 5.0, self.evaluate(v0.values[1]))
+      self.assertEqual(2.0 + 2 * 5.0, self.evaluate(v0.get(devices[1])))
       # Always reads from device 0.
       self.assertEqual(2.0 + 5.0,
                        self.evaluate(distribution.extended.read_var(v0)))
 
       # Update "sync on write" variable.
       self.evaluate(distribution.group(update1a))
-      self.assertEqual(3.0 + 7.0, self.evaluate(v1.values[0]))
+      self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[0])))
       # Writes are synchronized for v1, only the argument to assign_add on
       # device[0] is used.
-      self.assertEqual(3.0 + 7.0, self.evaluate(v1.values[1]))
+      self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[1])))
       self.assertEqual(3.0 + 7.0,
                        self.evaluate(distribution.extended.read_var(v1)))
 
@@ -376,15 +377,16 @@ class MirroredVariableCreationTest(test.TestCase):
       self.evaluate(distribution.group(update0b))
 
       # Update "sync on read" variable.
-      self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(v0.values[0]))
-      self.assertEqual(2.0 + 2 * 5.0 + 2 * 11.0, self.evaluate(v0.values[1]))
+      self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(v0.get(devices[0])))
+      self.assertEqual(2.0 + 2 * 5.0 + 2 * 11.0,
+                       self.evaluate(v0.get(devices[1])))
       self.assertEqual(2.0 + 5.0 + 11.0,
                        self.evaluate(distribution.extended.read_var(v0)))
 
       # Update "sync on write" variable.
       self.evaluate(distribution.group(update1b))
-      self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.values[0]))
-      self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.values[1]))
+      self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[0])))
+      self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[1])))
       self.assertEqual(3.0 + 7.0 + 13.0,
                        self.evaluate(distribution.extended.read_var(v1)))
 
@@ -446,7 +448,8 @@ class MirroredVariableCreationTest(test.TestCase):
       return v
 
     with distribution.scope():
-      names = values.DistributedValues(("foo", "bar"))
+      device_map = values.ReplicaDeviceMap(distribution.extended.worker_devices)
+      names = values.DistributedValues(device_map, ("foo", "bar"))
       with self.assertRaises(RuntimeError):
         _ = distribution.extended.call_for_each_replica(model_fn, args=(names,))
 
@@ -509,10 +512,10 @@ class MirroredVariableCreationTest(test.TestCase):
       ])
       expected_sum = 0.0
       expected_mean = 0.0
-      for i, _ in enumerate(distribution.extended.worker_devices):
+      for i, d in enumerate(distribution.extended.worker_devices):
         # Should see different values on different devices.
-        v_sum_value = self.evaluate(ret_v_sum.values[i].read_value())
-        v_mean_value = self.evaluate(ret_v_mean.values[i].read_value())
+        v_sum_value = self.evaluate(ret_v_sum.get(d).read_value())
+        v_mean_value = self.evaluate(ret_v_mean.get(d).read_value())
         expected = i + 3.0
         self.assertEqual(expected, v_sum_value)
         expected_sum += expected
@@ -575,7 +578,11 @@ class MirroredVariableCreationTest(test.TestCase):
       self.evaluate(variables.global_variables_initializer())
       # Assert that the aggregated value of the sync on read var is the sum
       # of the individual values before running the update ops.
-      self.assertEqual(1.0, self.evaluate(ret_v_sum.values[0].read_value()))
+      self.assertEqual(
+          1.0,
+          self.evaluate(
+              ret_v_sum.get(
+                  distribution.extended.worker_devices[0]).read_value()))
       self.assertEqual(2.0, self.evaluate(ret_v_sum))
 
       # Apply updates.
@@ -584,7 +591,11 @@ class MirroredVariableCreationTest(test.TestCase):
       self.evaluate(update_ops)
       # Assert that the aggregated value of the sync on read vars is the sum
       # of the individual values after running the update ops.
-      self.assertEqual(5.0, self.evaluate(ret_v_sum.values[0].read_value()))
+      self.assertEqual(
+          5.0,
+          self.evaluate(
+              ret_v_sum.get(
+                  distribution.extended.worker_devices[0]).read_value()))
       self.assertEqual(10.0, self.evaluate(ret_v_sum))
 
   def testVarDistributeStrategy(self, distribution):
diff --git a/tensorflow/python/distribute/one_device_strategy.py b/tensorflow/python/distribute/one_device_strategy.py
index 144ce6a8fce..7963a23c20f 100644
--- a/tensorflow/python/distribute/one_device_strategy.py
+++ b/tensorflow/python/distribute/one_device_strategy.py
@@ -251,7 +251,9 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
     suffix_loc = self._device.rfind("/")
     self._input_device = self._device[:suffix_loc] + "/device:CPU:0"
     worker_device_pairs = [(self._input_device, [self._device])]
-    self._input_workers = input_lib.InputWorkers(worker_device_pairs)
+    device_map = values.SingleDeviceMap(self._device)
+    self._input_workers = input_lib.InputWorkers(
+        device_map, worker_device_pairs)
 
   def _create_variable(self, next_creator, *args, **kwargs):
     colocate_with = kwargs.pop("colocate_with", None)
diff --git a/tensorflow/python/distribute/parameter_server_strategy.py b/tensorflow/python/distribute/parameter_server_strategy.py
index d5305ed910a..1815fc2a669 100644
--- a/tensorflow/python/distribute/parameter_server_strategy.py
+++ b/tensorflow/python/distribute/parameter_server_strategy.py
@@ -213,10 +213,9 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
     else:
       compute_devices = (worker_device,)
 
-    self._compute_devices = [
-        device_util.canonicalize(d) for d in compute_devices]
+    self._device_map = values.ReplicaDeviceMap(compute_devices)
     self._input_workers = input_lib.InputWorkers(
-        [(worker_device, compute_devices)])
+        self._device_map, [(worker_device, compute_devices)])
 
     # In distributed mode, place variables on ps jobs in a round-robin fashion.
     # Note that devices returned from `replica_device_setter` are not
@@ -254,9 +253,9 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
     logging.info(
         "Multi-worker ParameterServerStrategy with "
         "cluster_spec = %r, task_type = %r, task_id = %r, "
-        "num_ps_replicas = %r, is_chief = %r, compute_devices = %r, "
+        "num_ps_replicas = %r, is_chief = %r, device_map = %r, "
         "variable_device = %r", cluster_spec.as_dict(), task_type, task_id,
-        num_ps_replicas, self._is_chief, self._compute_devices,
+        num_ps_replicas, self._is_chief, self._device_map,
         self._variable_device)
 
   # TODO(yuefengz): get rid of cluster_resolver argument when contrib's
@@ -280,8 +279,6 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
 
       compute_devices = device_util.local_devices_from_num_gpus(num_gpus)
 
-    compute_devices = [device_util.canonicalize(d) for d in compute_devices]
-
     if parameter_device is None:
       # If there is only one GPU, put everything on that GPU. Otherwise, place
       # variables on CPU.
@@ -290,11 +287,11 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
       else:
         parameter_device = _LOCAL_CPU
 
+    self._device_map = values.ReplicaDeviceMap(compute_devices)
     self._input_workers = input_lib.InputWorkers(
-        [(worker_device, compute_devices)])
+        self._device_map, [(worker_device, compute_devices)])
 
     self._variable_device = parameter_device
-    self._compute_devices = compute_devices
     self._parameter_devices = (parameter_device,)
     self._is_chief = True
     self._cluster_spec = None
@@ -379,7 +376,8 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
       return tensor
     if not cross_device_ops_lib.check_destinations(destinations):
       # TODO(josh11b): Use current logical device instead of 0 here.
-      destinations = self._compute_devices
+      destinations = values.LogicalDeviceSpec(
+          device_map=self._device_map, logical_device=0)
     return self._cross_device_ops.broadcast(tensor, destinations)
 
   def _allow_variable_partition(self):
@@ -451,7 +449,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
   def _call_for_each_replica(self, fn, args, kwargs):
     # pylint: disable=protected-access
     return mirrored_strategy._call_for_each_replica(
-        self._container_strategy(), self._compute_devices, fn, args, kwargs)
+        self._container_strategy(), self._device_map, fn, args, kwargs)
 
   def _verify_destinations_not_different_worker(self, destinations):
     if not self._cluster_spec:
@@ -470,7 +468,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
     if not isinstance(value, values.DistributedValues):
       # pylint: disable=protected-access
       return cross_device_ops_lib.reduce_non_distributed_value(
-          reduce_op, value, destinations, self._num_replicas_in_sync)
+          reduce_op, self._device_map, value, destinations)
     return self._cross_device_ops.reduce(
         reduce_op, value, destinations=destinations)
 
@@ -607,15 +605,15 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
 
   @property
   def _num_replicas_in_sync(self):
-    return len(self._compute_devices)
+    return self._device_map.num_replicas_in_graph
 
   @property
   def worker_devices(self):
-    return self._compute_devices
+    return self._device_map.all_devices
 
   @property
   def worker_devices_by_replica(self):
-    return [[d] for d in self._compute_devices]
+    return self._device_map.devices_by_replica
 
   @property
   def parameter_devices(self):
diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index 3987bf390ff..2dd4309537a 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -201,6 +201,8 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
 
     self._host_device = device_util.get_host_for_device(self._tpu_devices[0])
 
+    self._device_map = values.ReplicaDeviceMap(self._tpu_devices)
+
     # Preload the data onto the TPUs.
     input_worker_devices = collections.OrderedDict()
     for tpu_device in self._tpu_devices:
@@ -208,7 +210,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
       input_worker_devices.setdefault(host_device, [])
       input_worker_devices[host_device].append(tpu_device)
     self._input_workers = input_lib.InputWorkers(
-        tuple(input_worker_devices.items()))
+        self._device_map, tuple(input_worker_devices.items()))
 
     # TODO(sourabhbajaj): Remove this once performance of running one step
     # at a time is comparable to multiple steps.
@@ -393,14 +395,16 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
 
     colocate_with = kwargs.pop("colocate_with", None)
     if colocate_with is None:
-      devices = self._tpu_devices
+      device_map = self._device_map
+      logical_device = 0  # TODO(josh11b): Get logical device from scope here.
     elif isinstance(colocate_with, numpy_dataset.SingleDevice):
       with ops.device(colocate_with.device):
         return next_creator(*args, **kwargs)
     else:
-      devices = colocate_with.devices
+      device_map = colocate_with.device_map
+      logical_device = colocate_with.logical_device
 
-    def _real_mirrored_creator(*args, **kwargs):  # pylint: disable=g-missing-docstring
+    def _real_mirrored_creator(devices, *args, **kwargs):  # pylint: disable=g-missing-docstring
       initial_value = None
       value_list = []
       for i, d in enumerate(devices):
@@ -430,9 +434,9 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
       return value_list
 
     return values.create_mirrored_variable(
-        self._container_strategy(), _real_mirrored_creator,
-        values.TPUMirroredVariable, values.TPUSyncOnReadVariable,
-        *args, **kwargs)
+        self._container_strategy(), device_map, logical_device,
+        _real_mirrored_creator, values.TPUMirroredVariable,
+        values.TPUSyncOnReadVariable, *args, **kwargs)
 
   def _reduce_to(self, reduce_op, value, destinations):
     if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
@@ -450,7 +454,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
       # replicas in which case `value` would be a single value or value could
       # be 0.
       return cross_device_ops_lib.reduce_non_distributed_value(
-          reduce_op, value, destinations, self._num_replicas_in_sync)
+          reduce_op, self._device_map, value, destinations)
 
     # TODO(cjfj): Detect when it is possible to use `cross_replica_sum`.
     # Always performs the reduction on the TPU host.
@@ -486,16 +490,14 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
     # Otherwise, we revert to MirroredStrategy behavior and update each variable
     # directly.
     updates = []
-    for i, v in enumerate(var.values):
+    for i, (d, v) in enumerate(zip(var.devices, var.values)):
       name = "update_%d" % i
-      with ops.device(v.device), \
-           distribute_lib.UpdateContext(i), \
-           ops.name_scope(name):
+      with ops.device(d), distribute_lib.UpdateContext(i), ops.name_scope(name):
         # If args and kwargs are not mirrored, the value is returned as is.
         updates.append(fn(v,
-                          *values.select_replica_mirrored(i, args),
-                          **values.select_replica_mirrored(i, kwargs)))
-    return values.update_regroup(self, updates, group)
+                          *values.select_device_mirrored(d, args),
+                          **values.select_device_mirrored(d, kwargs)))
+    return values.update_regroup(self, self._device_map, updates, group)
 
   def read_var(self, var):
     assert isinstance(var, values.TPUVariableMixin) or isinstance(
@@ -704,7 +706,8 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
             nest.pack_sequence_as(result[0], nest.flatten(replica_output))
             for replica_output in replicate_outputs
         ]
-      return values.regroup(replicate_outputs)
+      device_map = self._device_map  # pylint: disable=protected-access
+      return values.regroup(device_map, replicate_outputs)
 
     if context.executing_eagerly():
       tpu_function = def_function.function(tpu_function)
diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py
index df232545cfa..0c2a9ccdaac 100644
--- a/tensorflow/python/distribute/values.py
+++ b/tensorflow/python/distribute/values.py
@@ -21,6 +21,7 @@ from __future__ import print_function
 import collections
 import contextlib
 import weakref
+import six
 
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_lib
@@ -43,76 +44,325 @@ from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.util import nest
 
 
-def _get_current_replica_id_as_int():
-  """Returns the current replica ID as an integer, or `None`."""
-  replica_context = distribution_strategy_context.get_replica_context()
-  if replica_context:
+def _devices_match(d1, d2):
+  return device_util.canonicalize(d1) == device_util.canonicalize(d2)
+
+
+class DeviceMap(object):
+  """A mapping of replicas & logical device ids to devices."""
+
+  @property
+  def all_devices(self):
+    """Returns a tuple of strings with all devices in this DeviceMap."""
+    raise NotImplementedError("Required for DeviceMap implementations.")
+
+  @property
+  def devices_by_replica(self):
+    """Returns a tuple `t` where `t[replica]` is the devices for `replica`."""
+    raise NotImplementedError("Required for DeviceMap implementations.")
+
+  @property
+  def num_logical_devices(self):
+    """Count of the number of devices each replica may be defined across."""
+    raise NotImplementedError("Required for DeviceMap implementations.")
+
+  @property
+  def num_replicas_in_graph(self):
+    """Number of replicas defined in this graph."""
+    raise NotImplementedError("Required for DeviceMap implementations.")
+
+  def logical_device_from_values(self, values):
+    """Returns the logical device index `values` is on."""
+    raise NotImplementedError("Required for DeviceMap implementations.")
+
+  def logical_to_actual_devices(self, logical_device_id):
+    """Returns sequence of `num_replicas_in_graph` devices."""
+    raise NotImplementedError("Required for DeviceMap implementations.")
+
+  def select_for_current_replica(self, values, replica_context):
+    """Select the element of `values` for the current replica."""
+    raise NotImplementedError("Required for DeviceMap implementations.")
+
+  def replica_for_device(self, device):
+    """Return the replica id containing `device`."""
+    raise NotImplementedError("Required for DeviceMap implementations.")
+
+  def select_for_device(self, values, device):
+    """Select the element of `values` to access from `device`."""
+    raise NotImplementedError("Required for DeviceMap implementations.")
+
+  def is_device_in_replica(self, device, replica_id):
+    """Returns whether `device` is a member of replica `replica_id`."""
+    raise NotImplementedError("Required for DeviceMap implementations.")
+
+
+class SingleDeviceMap(DeviceMap):
+  """A device map for 1 non-computation device.
+
+  Use `SingleDeviceMap` when the device does not correspond to some replica of
+  the computation. For computation devices, use `ReplicaDeviceMap` below (even
+  if there is only a single device in the map).
+  """
+
+  def __init__(self, device):
+    """Initialize a `SingleDeviceMap`.
+
+    Args:
+      device: A string device.
+    """
+    assert isinstance(device, six.string_types)
+    self._device = device_util.canonicalize(device)
+    self._devices = (self._device,)
+
+  @property
+  def all_devices(self):
+    return self._devices
+
+  @property
+  def devices_by_replica(self):
+    raise ValueError("SingleDeviceMap not indexed by replicas")
+
+  @property
+  def num_logical_devices(self):
+    return 1
+
+  @property
+  def num_replicas_in_graph(self):
+    return 1
+
+  def logical_device_from_values(self, values):
+    del values
+    return 0
+
+  def logical_to_actual_devices(self, logical_device_id):
+    assert logical_device_id == 0
+    return self._devices
+
+  def select_for_current_replica(self, values, replica_context):
+    assert len(values) == 1
+    del replica_context
+    return values[0]
+
+  def replica_for_device(self, device):
+    raise ValueError("SingleDeviceMap not indexed by replicas")
+
+  def select_for_device(self, values, device):
+    assert len(values) == 1
+    if self._device != device:
+      raise ValueError("Device %s not found in %s (current device %s)" %
+                       (device, self._devices, device_util.current()))
+    return values[0]
+
+  def is_device_in_replica(self, device, replica_id):
+    raise ValueError("SingleDeviceMap not indexed by replicas")
+
+  def __repr__(self):
+    return "%s(%r)" % (self.__class__.__name__, self._device)
+
+
+class ReplicaDeviceMap(DeviceMap):
+  """A device map for 1 device per replica."""
+
+  def __init__(self, devices):
+    """Initialize a `ReplicaDeviceMap`.
+
+    Args:
+      devices: `devices[i]` is the string device for replica `i`.
+    """
+    self._devices = tuple(device_util.canonicalize(d) for d in devices)
+    if len(set(self._devices)) != len(self._devices):
+      raise ValueError("Duplicate devices in %s, after canonicalization: %s" %
+                       (devices, self._devices))
+    self._device_to_replica = {d: r for r, d in enumerate(self._devices)}
+
+  @property
+  def all_devices(self):
+    return self._devices
+
+  @property
+  def devices_by_replica(self):
+    return ((d,) for d in self._devices)
+
+  @property
+  def num_logical_devices(self):
+    return 1
+
+  @property
+  def num_replicas_in_graph(self):
+    return len(self._devices)
+
+  def logical_device_from_values(self, values):
+    del values
+    return 0
+
+  def logical_to_actual_devices(self, logical_device_id):
+    assert logical_device_id == 0
+    return self._devices
+
+  def select_for_current_replica(self, values, replica_context):
+    assert len(values) == len(self._devices)
     replica_id = replica_context.replica_id_in_sync_group
     if not isinstance(replica_id, int):
       replica_id = tensor_util.constant_value(replica_id)
-  else:
-    replica_id = distribute_lib.get_update_replica_id()
-  return replica_id
+    if replica_id is None:
+      replica_id = 0
+    return values[replica_id]
+
+  def replica_for_device(self, device):
+    return self._device_to_replica.get(device)
+
+  def select_for_device(self, values, device):
+    assert len(values) == len(self._devices)
+    replica_id = self._device_to_replica.get(device)
+    if replica_id is None:
+      raise ValueError("Device %s not found in %s (current device %s)" %
+                       (device, self._devices, device_util.current()))
+    return values[replica_id]
+
+  def is_device_in_replica(self, device, replica_id):
+    return _devices_match(device, self._devices[replica_id])
+
+  def __str__(self):
+    return "[%s]" % (", ".join(self._devices))
+
+  def __repr__(self):
+    return "%s([%s])" % (self.__class__.__name__, ", ".join(
+        repr(d) for d in self._devices))
+
+
+LogicalDeviceSpec = collections.namedtuple("LogicalDeviceSpec",
+                                           ("device_map", "logical_device"))
+
+
+class WorkerDeviceMap(DeviceMap):
+  """A device map for one value per worker."""
+
+  def __init__(self, devices, num_replicas_per_worker):
+    """Initialize a `WorkerDeviceMap`.
+
+    Args:
+      devices: `devices[i]` is the string device for worker `i` in in-graph
+        relication case; devices is single-element list for its corresponding
+        worker in between-graph case.
+      num_replicas_per_worker: number of replicas per worker, useful in in-graph
+        replication case.
+    """
+    self._devices = tuple(device_util.canonicalize(d) for d in devices)
+    if len(set(self._devices)) != len(self._devices):
+      raise ValueError("Duplicate devices in %s, after canonicalization: %s" %
+                       (devices, self._devices))
+    self._num_replicas_per_worker = num_replicas_per_worker
+
+  @property
+  def all_devices(self):
+    return self._devices
+
+  @property
+  def devices_by_replica(self):
+    raise ValueError("`WorkerDeviceMap` is not indexed by replicas")
+
+  @property
+  def num_logical_devices(self):
+    return 1
+
+  @property
+  def num_replicas_in_graph(self):
+    return len(self._devices)
+
+  def logical_device_from_values(self, values):
+    del values
+    return 0
+
+  def logical_to_actual_devices(self, logical_device_id):
+    assert logical_device_id == 0
+    return self._devices
+
+  def select_for_current_replica(self, values, replica_context):
+    return values[replica_context.replica_id_in_sync_group //
+                  self._num_replicas_per_worker]
+
+  def replica_for_device(self, device):
+    raise ValueError("`WorkerDeviceMap` not indexed by replicas")
+
+  def select_for_device(self, values, device):
+    # TODO(yuefengz): this should map from any device to the value on its
+    # corresponding worker.
+    return values[self._devices.index(device_util.canonicalize(device))]
+
+  def is_device_in_replica(self, device, replica_id):
+    raise ValueError("WorkerDeviceMap not indexed by replicas")
+
+  def __repr__(self):
+    return "%s(%r, num_replicas_per_worker=%d)" % (
+        self.__class__.__name__, self._devices, self._num_replicas_per_worker)
 
 
 class DistributedValues(object):
   """Holds a map from replica to values. Either PerReplica or Mirrored."""
 
-  def __init__(self, values):
+  def __init__(self, device_map, values, logical_device=None):
+    assert isinstance(device_map, DeviceMap)
+    self._device_map = device_map
     self._values = tuple(values)
+    if logical_device is None:
+      logical_device = device_map.logical_device_from_values(self._values)
+    self._logical_device = logical_device
 
-  def get(self):
+  # TODO(josh11b): Split this into two functions, one with device, one without.
+  def get(self, device=None):
     """Returns the value for the current device or raises a ValueError."""
-    replica_id = _get_current_replica_id_as_int()
-    if replica_id is None:
-      return self._get_cross_replica()
-    else:
-      return self._values[replica_id]
-
-  def _get_cross_replica(self):
-    raise NotImplementedError(
-        "This method should be overridden by sub-classes which support cross-"
-        "replica accesses.")
-
-  def _get_closest(self):
-    """Returns value in same replica or device if possible, else the primary."""
-    replica_id = _get_current_replica_id_as_int()
-    if replica_id is None:
-      # Try to find a value on the current device.
-      current_device = device_util.canonicalize(device_util.current())
-      for value in self._values:
-        if device_util.canonicalize(value.device) == current_device:
-          return value
-      return self.primary
-    else:
-      return self._values[replica_id]
+    if device is None:
+      replica_context = distribution_strategy_context.get_replica_context()
+      if replica_context:
+        return self._device_map.select_for_current_replica(
+            self._values, replica_context)
+      else:
+        update_replica_id = distribute_lib.get_update_replica_id()
+        if update_replica_id is None:
+          return self._get_cross_replica()
+        else:
+          return self._values[update_replica_id]
+    device = device_util.canonicalize(device)
+    return self._device_map.select_for_device(self._values, device)
 
   @property
   def primary(self):
     """Returns a representative component."""
     return self._values[0]
 
+  @property
+  def devices(self):
+    return self._device_map.logical_to_actual_devices(self._logical_device)
+
+  @property
+  def logical_device(self):
+    return self._logical_device
+
+  @property
+  def device_map(self):
+    return self._device_map
+
   # TODO(josh11b): Replace experimental_local_results with this?
   @property
   def values(self):
     return self._values
 
-  @property
-  def devices(self):
-    return tuple(v.device for v in self._values)
-
   @property
   def is_tensor_like(self):
     return all(tensor_util.is_tensor(v) for v in self._values)
 
   def __str__(self):
-    debug_str = ",\n".join(
-        "  %d: %s" % (i, v) for i, v in enumerate(self._values))
+    devices = self.devices
+    assert len(self._values) == len(devices)
+    debug_str = ",\n".join("  %d %s: %s" % (i, devices[i], self._values[i])
+                           for i in range(len(devices)))
     return "%s:{\n%s\n}" % (self.__class__.__name__, debug_str)
 
   def __repr__(self):
-    debug_repr = ",\n".join(
-        "  %d: %r" % (i, v) for i, v in enumerate(self._values))
+    devices = self.devices
+    assert len(self._values) == len(devices)
+    debug_repr = ",\n".join("  %d %s: %r" % (i, devices[i], self._values[i])
+                            for i in range(len(devices)))
     return "%s:{\n%s\n}" % (self.__class__.__name__, debug_repr)
 
 
@@ -273,22 +523,28 @@ class PerReplica(DistributedValues, composite_tensor.CompositeTensor):
 
   @property
   def _type_spec(self):
-    return PerReplicaSpec(
-        *(type_spec.type_spec_from_value(v) for v in self._values))
+    value_specs = nest.map_structure(type_spec.type_spec_from_value,
+                                     self._values)
+    return PerReplicaSpec(value_specs, self._device_map, self._logical_device)
 
 
 class PerReplicaSpec(type_spec.TypeSpec):
   """Type specification for a `PerReplica`."""
 
-  __slots__ = ["_value_specs"]
+  __slots__ = ["_value_specs", "_device_map", "_logical_device"]
 
   value_type = property(lambda self: PerReplica)
 
-  def __init__(self, *value_specs):
+  def __init__(self, value_specs, device_map, logical_device):
+    if isinstance(device_map, tuple):
+      device_map = self._deserialize_device_map(device_map)
     self._value_specs = tuple(value_specs)
+    self._device_map = device_map
+    self._logical_device = logical_device
 
   def _serialize(self):
-    return self._value_specs
+    device_map = self._serialize_device_map(self._device_map)
+    return (self._value_specs, device_map, self._logical_device)
 
   @property
   def _component_specs(self):
@@ -303,7 +559,34 @@ class PerReplicaSpec(type_spec.TypeSpec):
     return value._values  # pylint: disable=protected-access
 
   def _from_components(self, tensor_list):
-    return PerReplica(tensor_list)
+    return PerReplica(
+        self._device_map, tensor_list, logical_device=self._logical_device)
+
+  @staticmethod
+  def _serialize_device_map(device_map):
+    if isinstance(device_map, SingleDeviceMap):
+      return ("single", device_map.all_devices[0])
+    elif isinstance(device_map, ReplicaDeviceMap):
+      return ("replica", device_map.all_devices)
+    elif isinstance(device_map, WorkerDeviceMap):
+      return ("worker", device_map.all_devices,
+              device_map.num_replicas_per_worker)
+    else:
+      raise ValueError("PerReplicaSpec does not support device_map type %s" %
+                       type(device_map).__name__)
+
+  @staticmethod
+  def _deserialize_device_map(device_map_info):
+    device_map_type = device_map_info[0]
+    device_map_args = device_map_info[1:]
+    if device_map_type == "single":
+      return SingleDeviceMap(*device_map_args)
+    elif device_map_type == "replica":
+      return ReplicaDeviceMap(*device_map_args)
+    elif device_map_type == "worker":
+      return WorkerDeviceMap(*device_map_args)
+    else:
+      raise ValueError("Unexpected value in state tuple")
 
 
 # Note that unlike PerReplica, Mirrored values inherit from
@@ -313,7 +596,11 @@ class Mirrored(DistributedDelegate):
   """Holds a map from replica to values which are kept in sync."""
 
   def _get_cross_replica(self):
-    return self._get_closest()
+    device = device_util.canonicalize(device_util.current())
+    replica_id = self._device_map.replica_for_device(device)
+    if replica_id is None:
+      return self.primary
+    return self._values[replica_id]
 
   def _as_graph_element(self):
     obj = self.get()
@@ -369,9 +656,10 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
   # TODO(josh11b): Support changing the set of variables if e.g. if new
   # devices are joining or a device is to leave.
 
-  def __init__(self, strategy, values):
+  def __init__(self, strategy, device_map, values, logical_device=None):
     self._distribute_strategy = strategy
-    super(DistributedVariable, self).__init__(values)
+    super(DistributedVariable, self).__init__(
+        device_map, values, logical_device=logical_device)
     self._common_name = self.primary.name.split(":")[0]
     # Use a weakref to make it easy to map from the contained values
     # to the container without introducing a reference cycle.
@@ -421,6 +709,21 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
           tuple(v.initializer for v in self._values))
     return init_op
 
+  def _get_closest(self):
+    """Return member in the same replica if possible, else the primary."""
+    replica_context = distribution_strategy_context.get_replica_context()
+    if replica_context:
+      return self._device_map.select_for_current_replica(
+          self._values, replica_context)
+    update_replica_id = distribute_lib.get_update_replica_id()
+    if update_replica_id is not None:
+      return self._values[update_replica_id]
+    device = device_util.canonicalize(device_util.current())
+    replica_id = self._device_map.replica_for_device(device)
+    if replica_id is None:
+      return self.primary
+    return self._values[replica_id]
+
   def initialized_value(self):
     return self._get_closest().initialized_value()
 
@@ -463,12 +766,14 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
 
   @property
   def handle(self):
-    replica_id = _get_current_replica_id_as_int()
-    if replica_id is None:
-      raise ValueError("`handle` is not available outside the replica context"
-                       " or a `tf.distribute.Strategy.update()` call.")
-    else:
-      return self._values[replica_id].handle
+    replica_context = distribution_strategy_context.get_replica_context()
+    if replica_context is None:
+      update_replica_id = distribute_lib.get_update_replica_id()
+      if update_replica_id is None:
+        raise ValueError("`handle` is not available outside the replica context"
+                         " or a `tf.distribute.Strategy.update()` call.")
+      return self._values[update_replica_id].handle
+    return self.get().handle
 
   def eval(self, session=None):
     return self._get_closest().eval(session)
@@ -578,9 +883,9 @@ class TPUVariableMixin(object):
       raise AttributeError(
           "'{}' not accessible within a TPU context.".format(name))
 
-  def get(self):
-    if _enclosing_tpu_context() is None:
-      return super(TPUVariableMixin, self).get()
+  def get(self, device=None):
+    if (_enclosing_tpu_context() is None) or (device is not None):
+      return super(TPUVariableMixin, self).get(device=device)
     else:
       raise NotImplementedError(
           "`TPUVariableMixin.get()` is not supported within a TPU context.")
@@ -612,8 +917,10 @@ class TPUVariableMixin(object):
     if tpu_context is None:
       return self._get_closest().handle
     else:
-      return tpu_context.get_replicated_var_handle(
-          self._handle_id, self._values, self._is_mirrored())
+      return tpu_context.get_replicated_var_handle(self._handle_id,
+                                                   self._values,
+                                                   self._device_map,
+                                                   self._is_mirrored())
 
   @property
   def device(self):
@@ -728,8 +1035,8 @@ class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
 
 
 def create_mirrored_variable(  # pylint: disable=missing-docstring
-    strategy, real_mirrored_creator, mirrored_cls, sync_on_read_cls,
-    *args, **kwargs):
+    strategy, device_map, logical_device, real_mirrored_creator, mirrored_cls,
+    sync_on_read_cls, *args, **kwargs):
   # Figure out what collections this variable should be added to.
   # We'll add the MirroredVariable to those collections instead.
   var_collections = kwargs.pop("collections", None)
@@ -772,9 +1079,17 @@ def create_mirrored_variable(  # pylint: disable=missing-docstring
   # was never recorded on the tape instead of having to do this manually
   # here.
   with tape.stop_recording():
-    value_list = real_mirrored_creator(*args, **kwargs)
+    devices = device_map.logical_to_actual_devices(logical_device)
+    value_list = real_mirrored_creator(devices, *args, **kwargs)
+
     var_cls = sync_on_read_cls if is_sync_on_read else mirrored_cls
-    result = var_cls(strategy, value_list, aggregation)
+
+    result = var_cls(
+        strategy,
+        device_map,
+        value_list,
+        aggregation,
+        logical_device=logical_device)
 
   # Add the wrapped variable to the requested collections.
   # The handling of eager mode and the global step matches
@@ -805,8 +1120,14 @@ def create_mirrored_variable(  # pylint: disable=missing-docstring
 class MirroredVariable(DistributedVariable, Mirrored):
   """Holds a map from replica to variables whose values are kept in sync."""
 
-  def __init__(self, strategy, values, aggregation):
-    super(MirroredVariable, self).__init__(strategy, values)
+  def __init__(self,
+               strategy,
+               device_map,
+               values,
+               aggregation,
+               logical_device=None):
+    super(MirroredVariable, self).__init__(
+        strategy, device_map, values, logical_device=logical_device)
     self._aggregation = aggregation
 
   # The arguments to update() are automatically unwrapped so the update()
@@ -866,12 +1187,17 @@ class MirroredVariable(DistributedVariable, Mirrored):
     return self._aggregation
 
   def _get_cross_replica(self):
-    # Return identity, to avoid directly exposing the variable to the user and
-    # allowing it to be modified by mistake.
-    return array_ops.identity(Mirrored._get_cross_replica(self))
+    device = device_util.canonicalize(device_util.current())
+    replica_id = self._device_map.replica_for_device(device)
+    if replica_id is None:
+      return array_ops.identity(self.primary)
+    return array_ops.identity(self._values[replica_id])
 
   def _as_graph_element(self):
-    return self._get_closest()._as_graph_element()  # pylint: disable=protected-access
+    # pylint: disable=protected-access
+    if distribution_strategy_context.in_cross_replica_context():
+      return self.primary._as_graph_element()
+    return self.get()._as_graph_element()
 
   def _gather_saveables_for_checkpoint(self):
     """Overrides Trackable method.
@@ -1018,9 +1344,15 @@ def _assert_replica_context(strategy):
 class SyncOnReadVariable(DistributedVariable):
   """Holds a map from replica to variables whose values are reduced on save."""
 
-  def __init__(self, strategy, values, aggregation):
-    super(SyncOnReadVariable, self).__init__(strategy, values)
+  def __init__(self,
+               strategy,
+               device_map,
+               values,
+               aggregation,
+               logical_device=None):
     self._aggregation = aggregation
+    super(SyncOnReadVariable, self).__init__(
+        strategy, device_map, values, logical_device=logical_device)
 
   def assign_sub(self, *args, **kwargs):
     with _enter_or_assert_strategy(self._distribute_strategy):
@@ -1060,7 +1392,7 @@ class SyncOnReadVariable(DistributedVariable):
         # when saving.
         tensor = args[0]
         if self._aggregation == vs.VariableAggregation.SUM:
-          tensor = math_ops.cast(tensor / len(self._values), self.dtype)
+          tensor = math_ops.cast(tensor / len(self.devices), self.dtype)
         return control_flow_ops.group(
             tuple(_assign_on_device(v.device, v, tensor) for v in self._values))
       else:
@@ -1147,8 +1479,10 @@ class TPUSyncOnReadVariable(TPUVariableMixin, SyncOnReadVariable):
     return False
 
 
-def regroup(values, wrap_class=PerReplica):
+def regroup(device_map, values, wrap_class=PerReplica):
   """Makes a nest per-replica into a nest of PerReplica/Mirrored values."""
+  assert isinstance(device_map, DeviceMap)
+  assert len(values) == device_map.num_replicas_in_graph
   v0 = values[0]
 
   if isinstance(v0, list):
@@ -1157,7 +1491,8 @@ def regroup(values, wrap_class=PerReplica):
       assert len(v) == len(v0), ("len(v) == %d, len(v0) == %d, v: %s, v0: %s" %
                                  (len(v), len(v0), v, v0))
     return [
-        regroup(tuple(v[i] for v in values), wrap_class)
+        regroup(device_map, tuple(v[i]
+                                  for v in values), wrap_class)
         for i in range(len(v0))
     ]
 
@@ -1166,7 +1501,8 @@ def regroup(values, wrap_class=PerReplica):
       assert isinstance(v, tuple)
       assert len(v) == len(v0)
     regrouped_tuple = tuple(
-        regroup(tuple(v[i] for v in values), wrap_class)
+        regroup(device_map, tuple(v[i]
+                                  for v in values), wrap_class)
         for i in range(len(v0)))
     if hasattr(v0, "_fields"):
       # This tuple is in fact a namedtuple! Create a new namedtuple instance
@@ -1183,7 +1519,7 @@ def regroup(values, wrap_class=PerReplica):
       assert set(v.keys()) == v0keys, ("v[0].keys: %s  v[i].keys: %s" %
                                        (v0keys, set(v.keys())))
     return {
-        key: regroup(tuple(v[key] for v in values), wrap_class)
+        key: regroup(device_map, tuple(v[key] for v in values), wrap_class)
         for key in v0keys
     }
 
@@ -1219,14 +1555,20 @@ def regroup(values, wrap_class=PerReplica):
     # pylint: disable=protected-access
     assert not isinstance(v0, MirroredVariable), (
         "ids = %s, values = %s" % ([id(v) for v in values], values))
+    assert device_map.is_device_in_replica(
+        v0.device,
+        0), ("v0.device = %s, device_map = %s" % (v0.device, device_map))
     distributed_container = v0._distributed_container()
     assert distributed_container is not None
-    for v in values[1:]:
+    for r, v in enumerate(values[1:]):
+      assert device_map.is_device_in_replica(
+          v.device, r + 1), ("v.device = %s, r = %d, device_map = %s" %
+                             (v.device, r + 1, device_map))
       assert distributed_container is v._distributed_container()
     return distributed_container
   # pylint: enable=protected-access
 
-  return wrap_class(values)
+  return wrap_class(device_map, values)
 
 
 def select_replica(replica_id, structured):
@@ -1245,8 +1587,8 @@ def select_replica(replica_id, structured):
   return nest.map_structure(_get, structured)
 
 
-def select_replica_mirrored(replica_id, structured):
-  """Specialize a nest of regular & mirrored values for one replica."""
+def select_device_mirrored(device, structured):
+  """Specialize a nest of regular & mirrored values for one device."""
 
   def _get_mirrored(x):
     if isinstance(x, DistributedValues):
@@ -1254,23 +1596,23 @@ def select_replica_mirrored(replica_id, structured):
         raise TypeError(
             "Expected value to be mirrored across replicas: %s in %s." %
             (x, structured))
-      return x.values[replica_id]
+      return x.get(device)
     else:
       return x
 
   return nest.map_structure(_get_mirrored, structured)
 
 
-def update_regroup(extended, updates, group):
+def update_regroup(extended, device_map, updates, group):
   """Regroup for an update, with dependencies to ensure all updates execute."""
   if not group:
-    regrouped = regroup(updates, Mirrored)
+    regrouped = regroup(device_map, updates, Mirrored)
     return nest.map_structure(extended._local_results, regrouped)  # pylint: disable=protected-access
 
-  def _make_grouped_mirrored(values):
+  def _make_grouped_mirrored(device_map, values):
     """Convert per-replica list `values` into Mirrored type with grouping."""
     if len(values) == 1:
-      return Mirrored(values)
+      return Mirrored(device_map, values)
 
     # Make sure we run all updates. Without this, something like
     # session.run(extended.update(...)) may only update one replica.
@@ -1284,14 +1626,17 @@ def update_regroup(extended, updates, group):
 
     # Otherwise we need tensors with the same values as `values`, but
     # that have a dependency on `g`.
+    devices = device_map.logical_to_actual_devices(
+        device_map.logical_device_from_values(values))
+    assert len(values) == len(devices)
     with_dep = []
-    for v in values:
-      with ops.device(v.device), ops.control_dependencies([g]):
+    for v, d in zip(values, devices):
+      with ops.device(d), ops.control_dependencies([g]):
         with_dep.append(array_ops.identity(v))
 
-    return Mirrored(with_dep)
+    return Mirrored(device_map, with_dep)
 
-  return regroup(updates, _make_grouped_mirrored)
+  return regroup(device_map, updates, _make_grouped_mirrored)
 
 
 def value_container(val):
diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py
index 01022b6e110..d97d1155c82 100644
--- a/tensorflow/python/distribute/values_test.py
+++ b/tensorflow/python/distribute/values_test.py
@@ -24,7 +24,7 @@ import os
 from absl.testing import parameterized
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.distribute import combinations
-from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.distribute import tpu_strategy
@@ -55,35 +55,63 @@ from tensorflow.python.util import nest
 class DistributedValuesTest(test.TestCase):
 
   def testGetEager(self):
-    one = constant_op.constant(1)
-    two = constant_op.constant(2)
-    v = values.DistributedValues((one, two))
-    self.assertEqual(one, v.get())
-    with distribute_lib.ReplicaContext(None, 1):
-      self.assertEqual(two, v.get())
+    with ops.device("/device:CPU:0"):
+      one = constant_op.constant(1)
+      two = constant_op.constant(2)
+      device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0"))
+      v = values.DistributedValues(device_map, (one, two))
+      self.assertEqual(two, v.get("/device:GPU:0"))
+      self.assertEqual(one, v.get())
+      with self.assertRaises(ValueError):
+        self.assertIsNone(v.get("/device:GPU:2"))
 
   def testGetGraph(self):
-    with context.graph_mode(), ops.Graph().as_default():
+    with context.graph_mode(), \
+        ops.Graph().as_default(), \
+        ops.device("/device:CPU:0"):
       one = constant_op.constant(1)
       two = constant_op.constant(2)
-      v = values.DistributedValues((one, two))
+      device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0"))
+      v = values.DistributedValues(device_map, (one, two))
+      self.assertEqual(two, v.get("/device:GPU:0"))
       self.assertEqual(one, v.get())
-      with distribute_lib.ReplicaContext(None, 1):
-        self.assertEqual(two, v.get())
+      with self.assertRaises(ValueError):
+        self.assertIsNone(v.get("/device:GPU:2"))
+
+  def testCanonicalization(self):
+    canonical_cpu = ("/job:localhost/replica:0/task:0/device:CPU:0",)
+    v = values.DistributedValues(values.SingleDeviceMap(""), (42,))
+    self.assertEqual(canonical_cpu, v.devices)
+    v = values.DistributedValues(values.SingleDeviceMap("/device:CPU:0"), (42,))
+    self.assertEqual(canonical_cpu, v.devices)
+    v = values.DistributedValues(values.SingleDeviceMap("/cpu:0"), (42,))
+    self.assertEqual(canonical_cpu, v.devices)
+    v = values.DistributedValues(values.SingleDeviceMap("/CPU:0"), (42,))
+    self.assertEqual(canonical_cpu, v.devices)
 
   def testIsTensorLike(self):
-    with context.graph_mode(), ops.Graph().as_default():
+    with context.graph_mode(), \
+         ops.Graph().as_default(), \
+         ops.device("/device:CPU:0"):
       one = constant_op.constant(1)
       two = constant_op.constant(2)
-      v = values.DistributedValues((one, two))
+      device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0"))
+      v = values.DistributedValues(device_map, (one, two))
+      self.assertEqual(two, v.get("/device:GPU:0"))
+      self.assertEqual(one, v.get())
       self.assertTrue(v.is_tensor_like)
       self.assertTrue(tensor_util.is_tensor(v))
 
   def testIsTensorLikeWithAConstant(self):
-    with context.graph_mode(), ops.Graph().as_default():
+    with context.graph_mode(), \
+         ops.Graph().as_default(), \
+         ops.device("/device:CPU:0"):
       one = constant_op.constant(1)
       two = 2.0
-      v = values.DistributedValues((one, two))
+      device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0"))
+      v = values.DistributedValues(device_map, (one, two))
+      self.assertEqual(two, v.get("/device:GPU:0"))
+      self.assertEqual(one, v.get())
       self.assertFalse(v.is_tensor_like)
       self.assertFalse(tensor_util.is_tensor(v))
 
@@ -92,59 +120,62 @@ class DistributedDelegateTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testGetAttr(self):
-    class Foo(object):
+    with ops.device("/device:CPU:0"):
 
-      def __init__(self, x):
-        self.x = x
+      class Foo(object):
 
-    v = values.DistributedDelegate((Foo(7), Foo(8)))
-    self.assertEqual(7, v.x)
-    with self.assertRaises(AttributeError):
-      _ = v.y
+        def __init__(self, x):
+          self.x = x
+
+      device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0"))
+      v = values.DistributedDelegate(device_map, (Foo(7), Foo(8)))
+      self.assertEqual(7, v.x)
+      with self.assertRaises(AttributeError):
+        _ = v.y
 
   @test_util.run_in_graph_and_eager_modes
   def testOperatorOverride(self):
-    v = values.DistributedDelegate((7, 8))
-    # v should act like int(7).
-    self.assertEqual(8, v + 1)
-    self.assertEqual(10, 3 + v)
-    self.assertEqual(14, v + v)
-    self.assertEqual(5, v - 2)
-    self.assertEqual(6, 13 - v)
-    self.assertEqual(0, v - v)
-    self.assertEqual(14, v * 2)
-    self.assertEqual(21, 3 * v)
-    self.assertEqual(49, v * v)
-    self.assertEqual(3.5, v / 2)
-    self.assertEqual(1.5, 10.5 / v)
-    self.assertEqual(3, v // 2)
-    self.assertEqual(2, 15 // v)
-    self.assertEqual(1, v % 2)
-    self.assertEqual(2, 16 % v)
-    # pylint: disable=g-generic-assert
-    self.assertTrue(v < 12)
-    self.assertTrue(v <= 12)
-    self.assertFalse(v > 12)
-    self.assertFalse(v >= 12)
-    self.assertFalse(12 < v)
-    self.assertFalse(12 <= v)
-    self.assertTrue(12 > v)
-    self.assertTrue(12 >= v)
-    # pylint: enable=g-generic-assert
-    self.assertEqual(3, v & 3)
-    self.assertEqual(3, 11 & v)
-    self.assertEqual(15, v | 8)
-    self.assertEqual(23, 16 | v)
-    self.assertEqual(4, v ^ 3)
-    self.assertEqual(12, 11 ^ v)
-    self.assertEqual(343, pow(v, 3))
-    self.assertEqual(3, pow(v, 3, 10))
-    self.assertEqual(128, pow(2, v))
-    self.assertEqual(-7, -v)
-    self.assertEqual(~7, ~v)
-    self.assertEqual(7, abs(v))
-    with self.assertRaises(TypeError):
-      _ = v[2]
+    with ops.device("/device:CPU:0"):
+      device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0"))
+      v = values.DistributedDelegate(device_map, (7, 8))
+      # v should act like int(7).
+      self.assertEqual(8, v + 1)
+      self.assertEqual(10, 3 + v)
+      self.assertEqual(14, v + v)
+      self.assertEqual(5, v - 2)
+      self.assertEqual(6, 13 - v)
+      self.assertEqual(0, v - v)
+      self.assertEqual(14, v * 2)
+      self.assertEqual(21, 3 * v)
+      self.assertEqual(49, v * v)
+      self.assertEqual(3.5, v / 2)
+      self.assertEqual(1.5, 10.5 / v)
+      self.assertEqual(3, v // 2)
+      self.assertEqual(2, 15 // v)
+      self.assertEqual(1, v % 2)
+      self.assertEqual(2, 16 % v)
+      self.assertTrue(v < 12)
+      self.assertTrue(v <= 12)
+      self.assertFalse(v > 12)
+      self.assertFalse(v >= 12)
+      self.assertFalse(12 < v)
+      self.assertFalse(12 <= v)
+      self.assertTrue(12 > v)
+      self.assertTrue(12 >= v)
+      self.assertEqual(3, v & 3)
+      self.assertEqual(3, 11 & v)
+      self.assertEqual(15, v | 8)
+      self.assertEqual(23, 16 | v)
+      self.assertEqual(4, v ^ 3)
+      self.assertEqual(12, 11 ^ v)
+      self.assertEqual(343, pow(v, 3))
+      self.assertEqual(3, pow(v, 3, 10))
+      self.assertEqual(128, pow(2, v))
+      self.assertEqual(-7, -v)
+      self.assertEqual(~7, ~v)
+      self.assertEqual(7, abs(v))
+      with self.assertRaises(TypeError):
+        _ = v[2]
 
 
 def _device_str(d):
@@ -154,15 +185,15 @@ def _device_str(d):
 def _nested_value(d):
   return ("a" + d, ["b" + d, {"c": "d" + d, "e": "f" + d}, "g" + d], "h" + d)
 
-
 def _make_mirrored_val(init_val=5.0):
   v = []
   devices = ["/device:GPU:0", "/device:CPU:0"]
   for d, _ in zip(devices, ["v", "v/replica"]):
     with ops.device(d):
       v.append(constant_op.constant(init_val))
-  return values.Mirrored(v)
-
+  device_map = values.ReplicaDeviceMap(devices)
+  mirrored = values.Mirrored(device_map, v)
+  return mirrored
 
 def _make_mirrored():
   v = []
@@ -171,20 +202,29 @@ def _make_mirrored():
     with ops.device(d):
       v.append(variable_scope.get_variable(
           name=n, initializer=init, use_resource=True))
-  mirrored = values.MirroredVariable(
-      None, v, variable_scope.VariableAggregation.SUM)
-  return mirrored
+  device_map = values.ReplicaDeviceMap(devices)
+  mirrored = values.MirroredVariable(None, device_map, v,
+                                     variable_scope.VariableAggregation.SUM)
+  return v, device_map, mirrored
 
 
 class RegroupAndSelectDeviceTest(test.TestCase):
 
   def _is_per_replica(self, result, expected, klass=values.PerReplica):
     self.assertIsInstance(result, klass)
-    for i, exp in enumerate(expected):
-      self.assertEqual(exp, result.values[i])
+    # We canonicalize the devices to match the device strings returned
+    # by PerReplica, which also does device string canonicalization.
+    devices = [device_util.canonicalize(_device_str(i))
+               for i in range(len(expected))]
+    self.assertEqual(set(devices), set(result.devices))
+    for i, d in enumerate(devices):
+      self.assertEqual(expected[i], result.get(d))
+      self.assertEqual(expected[i], result.get(_device_str(i)))
 
   def testNested(self):
-    result = values.regroup((_nested_value("1"), _nested_value("2")))
+    device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
+    result = values.regroup(device_map,
+                            (_nested_value("1"), _nested_value("2")))
     self.assertIsInstance(result, tuple)
     self.assertEqual(3, len(result))
     self._is_per_replica(result[0], ["a1", "a2"])
@@ -207,14 +247,16 @@ class RegroupAndSelectDeviceTest(test.TestCase):
                      values.select_replica(1, result))
     # select_device_mirrored() should fail due to non-mirrored values
     with self.assertRaises(TypeError):
-      values.select_replica_mirrored(0, result)
+      values.select_device_mirrored(_device_str(0), result)
     with self.assertRaises(TypeError):
-      values.select_replica_mirrored(1, result)
+      values.select_device_mirrored(_device_str(1), result)
 
   def testWrapClass(self):
     # Normally a mirrored value would be the same across devices, but
     # for a test it is convenient to be able to tell the values apart.
-    result = values.regroup((_nested_value("1"), _nested_value("2")),
+    device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
+    result = values.regroup(device_map,
+                            (_nested_value("1"), _nested_value("2")),
                             values.Mirrored)
     self.assertIsInstance(result, tuple)
     self.assertEqual(3, len(result))
@@ -238,12 +280,13 @@ class RegroupAndSelectDeviceTest(test.TestCase):
                      values.select_replica(1, result))
     # Values are marked as mirrored, so select_device_mirrored() is allowed.
     self.assertEqual(_nested_value("1"),
-                     values.select_replica_mirrored(0, result))
+                     values.select_device_mirrored(_device_str(0), result))
     self.assertEqual(_nested_value("2"),
-                     values.select_replica_mirrored(1, result))
+                     values.select_device_mirrored(_device_str(1), result))
 
   def testWrapAListOfTwoTuples(self):
-    result = values.regroup([("1", "2"), ("3", "4")])
+    device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
+    result = values.regroup(device_map, [("1", "2"), ("3", "4")])
     self.assertIsInstance(result, tuple)
     self.assertEqual(2, len(result))
     self._is_per_replica(result[0], ("1", "3"), values.PerReplica)
@@ -252,13 +295,14 @@ class RegroupAndSelectDeviceTest(test.TestCase):
   def testMirroredContainer(self):
     if context.num_gpus() < 1 and context.executing_eagerly():
       self.skipTest("A GPU is not available for this test in eager mode.")
-    mirrored = _make_mirrored()
-    result = values.regroup(mirrored.values)
+    v, device_map, mirrored = _make_mirrored()
+    result = values.regroup(device_map, v)
     self.assertIs(mirrored, result)
 
   def testSameId(self):
     foo = object()
-    result = values.regroup((("a", foo), ("b", foo)))
+    device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
+    result = values.regroup(device_map, (("a", foo), ("b", foo)))
     self.assertIsInstance(result, tuple)
     self.assertEqual(2, len(result))
     self._is_per_replica(result[0], ["a", "b"])
@@ -277,7 +321,8 @@ class RegroupAndSelectDeviceTest(test.TestCase):
     self.assertIs(foo, result_1[1])
 
   def testOneDevice(self):
-    result = values.regroup((_nested_value("1"),))
+    device_map = values.ReplicaDeviceMap((_device_str(0),))
+    result = values.regroup(device_map, (_nested_value("1"),))
     # On one device regroup() and select_replica() are basically identity.
     self.assertEqual(_nested_value("1"), result)
     self.assertEqual(_nested_value("1"),
@@ -288,9 +333,10 @@ class RegroupAndSelectDeviceTest(test.TestCase):
     with ops.device(d):
       v = variable_scope.get_variable(
           name="v", initializer=1., use_resource=True)
-    mirrored = values.MirroredVariable(None, (v,),
+      device_map = values.ReplicaDeviceMap((d,))
+    mirrored = values.MirroredVariable(None, device_map, (v,),
                                        variable_scope.VariableAggregation.SUM)
-    result = values.regroup((v,))
+    result = values.regroup(device_map, (v,))
     self.assertIs(mirrored, result)
 
   def testNamedTuple(self):
@@ -310,6 +356,7 @@ class RegroupAndSelectDeviceTest(test.TestCase):
             scaffold=scaffold or Scaffold())
 
     with context.graph_mode(), ops.Graph().as_default():
+      devices = []
       created_estimator_specs = []
 
       for device_id in range(3):
@@ -317,21 +364,25 @@ class RegroupAndSelectDeviceTest(test.TestCase):
             mode=mode_keys.EstimatorModeKeys.TRAIN,
             loss=constant_op.constant(device_id / 2),
             train_op=array_ops.identity(constant_op.constant(device_id)))
+        devices.append(_device_str(device_id))
         created_estimator_specs.append(spec)
 
-      merged_estimator_spec = values.regroup(created_estimator_specs)
+      device_map = values.ReplicaDeviceMap(devices)
+      merged_estimator_spec = values.regroup(
+          device_map, created_estimator_specs)
 
       self.assertIsInstance(merged_estimator_spec, EstimatorSpec)
       self.assertEqual(mode_keys.EstimatorModeKeys.TRAIN,
                        merged_estimator_spec.mode)
       for device_id in range(3):
+        d = _device_str(device_id)
         self.assertEqual(created_estimator_specs[device_id].loss,
-                         merged_estimator_spec.loss.values[device_id])
+                         merged_estimator_spec.loss.get(d))
         self.assertEqual(created_estimator_specs[device_id].train_op,
-                         merged_estimator_spec.train_op.values[device_id])
+                         merged_estimator_spec.train_op.get(d))
         # Scaffold is populated by `EstimatorSpec.__new__`.
         self.assertEqual(created_estimator_specs[device_id].scaffold,
-                         merged_estimator_spec.scaffold.values[device_id])
+                         merged_estimator_spec.scaffold.get(d))
         self.assertIsInstance(created_estimator_specs[device_id].scaffold,
                               Scaffold)
         # Also test that we can undo the merge using select_replica()
@@ -350,26 +401,28 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
     if context.num_gpus() < 1 and context.executing_eagerly():
       self.skipTest("A GPU is not available for this test in eager mode.")
 
-    mirrored = _make_mirrored()
-    v = mirrored.values[0]
-    self.assertEqual(v.name, mirrored.name)
-    self.assertEqual(v.dtype, mirrored.dtype)
-    self.assertEqual(v.shape, mirrored.shape)
+    v, _, mirrored = _make_mirrored()
+
+    self.assertEqual(v[0].name, mirrored.name)
+    self.assertEqual(v[0].dtype, mirrored.dtype)
+    self.assertEqual(v[0].shape, mirrored.shape)
 
   @test_util.run_in_graph_and_eager_modes(config=config)
   def testVariableOnAnotherDevice(self):
     v = variable_scope.get_variable(
         name="v", initializer=[1.], use_resource=True)
-    mirrored = values.MirroredVariable(
-        None, (v,), variable_scope.VariableAggregation.MEAN)
+    device_map = values.ReplicaDeviceMap(("/job:foo/device:CPU:0",))
+    mirrored = values.MirroredVariable(None, device_map, (v,),
+                                       variable_scope.VariableAggregation.MEAN)
 
     self.assertEqual(v.name, mirrored.name)
     self.assertEqual(v.dtype, mirrored.dtype)
     self.assertEqual(v.shape, mirrored.shape)
 
-  def _assign_mirrored(self, v, new):
-    for var, n in zip(v.values, new):
-      self.evaluate(var.assign(n))
+  def _assign_mirrored(self, devices, v, new):
+    for d, var, n in zip(devices, v, new):
+      with ops.device(d):
+        self.evaluate(var.assign(n))
 
   def _save_return_saver(self, sess, var):
     saver = saver_lib.Saver(var_list=[var])
@@ -392,17 +445,17 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
       self.skipTest("A GPU is not available for this test in eager mode.")
 
     with self.cached_session(config=self.config) as sess:
-      mirrored = _make_mirrored()
-      v = mirrored.values
+      v, device_map, mirrored = _make_mirrored()
+      devices = device_map.all_devices
 
       # Overwrite the initial values.
-      self._assign_mirrored(mirrored, [3., 4.])
+      self._assign_mirrored(devices, v, [3., 4.])
 
       # Saves the current value of v[0], 3.
       save_path, saver = self._save_return_saver(sess, mirrored)
 
       # Change the values between save and restore.
-      self._assign_mirrored(mirrored, [5., 6.])
+      self._assign_mirrored(devices, v, [5., 6.])
 
       # Restores the saved value of 3. to both variables.
       saver.restore(sess, save_path)
@@ -411,16 +464,17 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
   def _save_mirrored(self):
     """Save variables with mirroring, returns save_path."""
     with self.session(graph=ops.Graph()) as sess:
-      mirrored = _make_mirrored()
+      v, device_map, mirrored = _make_mirrored()
+      devices = device_map.all_devices
 
       # Overwrite the initial values.
-      self._assign_mirrored(mirrored, [3., 4.])
+      self._assign_mirrored(devices, v, [3., 4.])
 
       # Saves the current value of v[0], 3.
       save_path = self._save(sess, mirrored)
 
       # Change the values between save and restore.
-      self._assign_mirrored(mirrored, [5., 6.])
+      self._assign_mirrored(devices, v, [5., 6.])
     return save_path
 
   def _save_normal(self):
@@ -456,11 +510,11 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
   def _restore_mirrored(self, save_path):
     """Restore to variables with mirroring in a fresh graph."""
     with self.session(graph=ops.Graph()) as sess:
-      mirrored = _make_mirrored()
-      v = mirrored.values
+      v, device_map, mirrored = _make_mirrored()
+      devices = device_map.all_devices
 
       # Overwrite the initial values.
-      self._assign_mirrored(mirrored, [7., 8.])
+      self._assign_mirrored(devices, v, [7., 8.])
 
       # Restores the saved value of 3. to both variables.
       saver = saver_lib.Saver(var_list=[mirrored])
@@ -518,7 +572,8 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
         v = variable_scope.get_variable(
             name="v", initializer=1., use_resource=True)
       mirrored = values.MirroredVariable(
-          distribution, (v,), variable_scope.VariableAggregation.MEAN)
+          distribution, values.ReplicaDeviceMap(("/device:GPU:0",)), (v,),
+          variable_scope.VariableAggregation.MEAN)
       sess.run(variables_lib.global_variables_initializer())
       sess.run({"complicated": mirrored})
 
@@ -689,6 +744,7 @@ def _make_replica_local(method, strategy=None):
   else:
     devices = strategy.extended.worker_devices
 
+  device_map = values.ReplicaDeviceMap(devices)
   v = []
   for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
     with ops.device(d):
@@ -699,7 +755,7 @@ def _make_replica_local(method, strategy=None):
     var_cls = values.TPUSyncOnReadVariable
   else:
     var_cls = values.SyncOnReadVariable
-  replica_local = var_cls(strategy, v, method)
+  replica_local = var_cls(strategy, device_map, v, method)
   return v, replica_local
 
 
@@ -721,6 +777,20 @@ class SyncOnReadVariablePropertiesTest(test.TestCase):
     self.assertEqual(variable_scope.VariableAggregation.SUM,
                      replica_local.aggregation)
 
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testVariableOnAnotherDevice(self):
+    v = variable_scope.get_variable(
+        name="v", initializer=[1.], use_resource=True)
+    device_map = values.ReplicaDeviceMap(("/job:foo/device:CPU:0",))
+    replica_local = values.SyncOnReadVariable(
+        None, device_map, (v,), variable_scope.VariableAggregation.MEAN)
+
+    self.assertEqual(v.name, replica_local.name)
+    self.assertEqual(v.dtype, replica_local.dtype)
+    self.assertEqual(v.shape, replica_local.shape)
+    self.assertEqual(variable_scope.VariableAggregation.MEAN,
+                     replica_local.aggregation)
+
   def testTensorConversion(self):
     with context.graph_mode():
       _, replica_local = _make_replica_local(
@@ -742,8 +812,9 @@ class SyncOnReadVariablePropertiesTest(test.TestCase):
 
     v = variable_scope.get_variable(
         name="v", initializer=[1.], use_resource=True)
+    device_map = values.ReplicaDeviceMap(("/job:foo/device:CPU:0",))
     replica_local = values.SyncOnReadVariable(
-        None, (v,), variable_scope.VariableAggregation.MEAN)
+        None, device_map, (v,), variable_scope.VariableAggregation.MEAN)
     self.assertEqual(2., self.evaluate(add1(replica_local)))
 
 
@@ -1100,7 +1171,6 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
     vals = self.evaluate(v[0].values)
     self.assertAllEqual(vals[0], vals[1])
 
-
 class MirroredTest(test.TestCase):
 
   def testAddOp(self):
@@ -1121,39 +1191,49 @@ class MirroredTest(test.TestCase):
 class PerReplicaTest(test.TestCase, parameterized.TestCase):
 
   def testTypeSpec(self):
+    device_map = values.SingleDeviceMap("CPU")
     vals = (constant_op.constant(1.),)
-    per_replica = values.PerReplica(vals)
+    per_replica = values.PerReplica(device_map, vals)
 
     spec = per_replica._type_spec
     self.assertEqual(spec._value_specs,
                      (tensor_spec.TensorSpec([], dtypes.float32),))
+    self.assertEqual(spec._device_map, per_replica.device_map)
+    self.assertEqual(spec._logical_device, per_replica.logical_device)
 
   def testTypeSpecRoundTrip(self):
+    device_map = values.SingleDeviceMap("CPU")
     vals = (constant_op.constant(1.),)
-    per_replica = values.PerReplica(vals)
+    per_replica = values.PerReplica(device_map, vals)
 
     spec = per_replica._type_spec
     tensor_list = spec._to_components(per_replica)
     reconstructed = spec._from_components(tensor_list)
 
+    self.assertEqual(per_replica.device_map, reconstructed.device_map)
+    self.assertEqual(per_replica.logical_device, reconstructed.logical_device)
     self.assertAllEqual(per_replica.values, reconstructed.values)
 
   def testTypeSpecNest(self):
+    device_map = values.ReplicaDeviceMap(["CPU:0", "CPU:1"])
     vals = (constant_op.constant(1.), constant_op.constant([5., 6.0]),)
-    per_replica = values.PerReplica(vals)
+    per_replica = values.PerReplica(device_map, vals)
 
     # Note: nest.map_structutre exercises nest.flatten and
     # nest.pack_sequence_as.
-    result = nest.map_structure(
-        lambda t: t + 10, per_replica, expand_composites=True)
+    result = nest.map_structure(lambda t: t + 10, per_replica,
+                                expand_composites=True)
 
+    self.assertEqual(per_replica.device_map, result.device_map)
+    self.assertEqual(per_replica.logical_device, result.logical_device)
     self.assertLen(result.values, 2)
     self.assertAllEqual(result.values[0], 11.)
     self.assertAllEqual(result.values[1], [15., 16.0])
 
   @test_util.run_in_graph_and_eager_modes
   def testIsGraphTensor(self):
-    per_replica = values.PerReplica((constant_op.constant(1.),))
+    per_replica = values.PerReplica(values.SingleDeviceMap("CPU"),
+                                    (constant_op.constant(1.),))
     for t in nest.flatten(per_replica, expand_composites=True):
       self.assertEqual(hasattr(t, "graph"), not context.executing_eagerly())
 
@@ -1165,7 +1245,8 @@ class PerReplicaTest(test.TestCase, parameterized.TestCase):
       traces.append(None)  # Only happens on trace.
       return x
 
-    per_replica = values.PerReplica((constant_op.constant(1.),))
+    per_replica = values.PerReplica(
+        values.SingleDeviceMap("CPU"), (constant_op.constant(1.),))
 
     # Trace once.
     f(per_replica)
@@ -1181,11 +1262,14 @@ class PerReplicaTest(test.TestCase, parameterized.TestCase):
       output = f(per_replica)
       self.assertIsInstance(output, values.PerReplica)
       self.assertAllEqual(output._values, per_replica._values)
+      self.assertAllEqual(output._device_map, per_replica._device_map)
+      self.assertAllEqual(output._logical_device, per_replica._logical_device)
       self.assertEmpty(traces)  # Make sure we're not re-tracing `f`.
 
   def testFunctionCanReturnPerReplica(self):
     f = def_function.function(lambda x: x)
-    x = values.PerReplica((constant_op.constant(1.),))
+    x = values.PerReplica(
+        values.SingleDeviceMap("CPU"), (constant_op.constant(1.),))
     y = f(x)
     self.assertIsNot(x, y)
     nest.map_structure(self.assertAllEqual, x, y, expand_composites=True)
@@ -1193,32 +1277,40 @@ class PerReplicaTest(test.TestCase, parameterized.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testCondWithTensorValues(self):
-    per_replica_1 = values.PerReplica((constant_op.constant("a"),))
-    per_replica_2 = values.PerReplica((constant_op.constant(["b", "c"]),))
+    device_map = values.SingleDeviceMap("CPU")
+    per_replica_1 = values.PerReplica(device_map, (constant_op.constant("a"),))
+    per_replica_2 = values.PerReplica(device_map,
+                                      (constant_op.constant(["b", "c"]),))
     condition = array_ops.placeholder_with_default(True, [])
 
     result = control_flow_ops.cond(
         condition, lambda: per_replica_1, lambda: per_replica_2)
 
+    self.assertEqual(per_replica_1.device_map, result.device_map)
+    self.assertEqual(per_replica_1.logical_device, result.logical_device)
     self.assertLen(result.values, 1)
     self.assertAllEqual(result.values[0], "a")
 
   @test_util.run_in_graph_and_eager_modes
   def testCondWithValuesConvertibleToTensor(self):
-    per_replica_1 = values.PerReplica(("a",))
-    per_replica_2 = values.PerReplica(("b",))
+    device_map = values.SingleDeviceMap("CPU")
+    per_replica_1 = values.PerReplica(device_map, ("a",))
+    per_replica_2 = values.PerReplica(device_map, ("b",))
     condition = array_ops.placeholder_with_default(True, [])
 
     result = control_flow_ops.cond(
         condition, lambda: per_replica_1, lambda: per_replica_2)
 
+    self.assertEqual(per_replica_1.device_map, result.device_map)
+    self.assertEqual(per_replica_1.logical_device, result.logical_device)
     self.assertLen(result.values, 1)
     self.assertAllEqual(result.values[0], "a")
 
   @test_util.build_as_function_and_v1_graph
   def testCondWithValuesNotConvertibleToTensor(self):
-    per_replica_1 = values.PerReplica(({"a"},))
-    per_replica_2 = values.PerReplica(({"b", "c"},))
+    device_map = values.SingleDeviceMap("CPU")
+    per_replica_1 = values.PerReplica(device_map, (set(["a"]),))
+    per_replica_2 = values.PerReplica(device_map, (set(["b", "c"]),))
     condition = array_ops.placeholder(dtypes.bool, [])
 
     with self.assertRaisesRegex(TypeError, "Could not build a TypeSpec for"):
@@ -1226,5 +1318,88 @@ class PerReplicaTest(test.TestCase, parameterized.TestCase):
           condition, lambda: per_replica_1, lambda: per_replica_2)
 
 
+class WorkerDeviceMapTest(test.TestCase, parameterized.TestCase):
+
+  class ReplicaContext(object):
+
+    def __init__(self, replica_id_in_sync_group):
+      self.replica_id_in_sync_group = replica_id_in_sync_group
+
+  def testBasic(self):
+    devices = [
+        "/job:worker/replica:0/task:0/device:CPU:0",
+        "/job:worker/replica:0/task:2/device:CPU:0"
+    ]
+    device_map = values.WorkerDeviceMap(devices, 1)
+    self.assertAllEqual(devices, device_map.all_devices)
+
+    # pylint:disable=pointless-statement
+    with self.assertRaisesWithPredicateMatch(
+        ValueError, "`WorkerDeviceMap` is not indexed by replicas"):
+      device_map.devices_by_replica
+
+    self.assertEqual(1, device_map.num_logical_devices)
+
+    self.assertEqual(2, device_map.num_replicas_in_graph)
+
+    self.assertEqual(0, device_map.logical_device_from_values(["a", "b"]))
+
+    self.assertAllEqual(devices, device_map.logical_to_actual_devices(0))
+
+    replica_context = WorkerDeviceMapTest.ReplicaContext(1)
+    self.assertEqual(
+        "b", device_map.select_for_current_replica(["a", "b"], replica_context))
+
+    with self.assertRaisesWithPredicateMatch(
+        ValueError, "`WorkerDeviceMap` not indexed by replicas"):
+      device_map.replica_for_device(devices[1])
+
+    self.assertEqual("b", device_map.select_for_device(["a", "b"], devices[1]))
+
+    with self.assertRaisesWithPredicateMatch(
+        ValueError, "WorkerDeviceMap not indexed by replicas"):
+      device_map.is_device_in_replica(devices[1], 1)
+
+    self.assertEqual(
+        "WorkerDeviceMap(('/job:worker/replica:0/task:0/device:CPU:0', "
+        "'/job:worker/replica:0/task:2/device:CPU:0'), "
+        "num_replicas_per_worker=1)", repr(device_map))
+
+  def testMultipleReplicasPerWorker(self):
+    devices = [
+        "/job:worker/replica:0/task:0/device:CPU:0",
+        "/job:worker/replica:0/task:2/device:CPU:0"
+    ]
+    device_map = values.WorkerDeviceMap(devices, 2)
+
+    replica_context = WorkerDeviceMapTest.ReplicaContext(3)
+    self.assertEqual(
+        "b", device_map.select_for_current_replica(["a", "b"], replica_context))
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+              strategy_combinations.tpu_strategy,
+          ],
+          mode=["graph", "eager"]))
+  def testExperimentalLocalResultsOrder(self, distribution):
+    # Create 2 devices in the device map, where the alphabetical order and the
+    # actual order of devices are different.
+    device_map = values.ReplicaDeviceMap(["CPU:2", "CPU:10"])
+    vals = (
+        constant_op.constant(1.),
+        constant_op.constant([5., 6.0]),
+    )
+    per_replica = values.PerReplica(device_map, vals)
+    results = self.evaluate(
+        distribution.experimental_local_results(per_replica))
+
+    # We expect the outputs order the same as the inputs order.
+    self.assertLen(results, 2)
+    self.assertAllEqual(1.0, results[0])
+    self.assertAllEqual([5., 6.], results[1])
+
+
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/keras/distribute/keras_utils_test.py b/tensorflow/python/keras/distribute/keras_utils_test.py
index 3529935dd51..bf328e447c1 100644
--- a/tensorflow/python/keras/distribute/keras_utils_test.py
+++ b/tensorflow/python/keras/distribute/keras_utils_test.py
@@ -197,8 +197,9 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase):
     with self.cached_session():
       a = constant_op.constant([1, 2], shape=(1, 2))
       b = constant_op.constant([[1, 2], [1, 2]], shape=(2, 2))
-      x = values.DistributedValues((a, b))
-      y = values.DistributedValues((a, a))
+      device_map = values.ReplicaDeviceMap(('/device:CPU:0', '/device:GPU:0'))
+      x = values.DistributedValues(device_map, (a, b))
+      y = values.DistributedValues(device_map, (a, a))
       # Removed device and input tensor shape details from the error message
       # since the order of the device and the corresponding input tensor shape
       # is not deterministic over different runs.
@@ -221,8 +222,9 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase):
     with self.cached_session():
       a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32)
       b = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.float64)
-      x = values.DistributedValues((a, b))
-      y = values.DistributedValues((a, a))
+      device_map = values.ReplicaDeviceMap(('/device:CPU:0', '/device:GPU:0'))
+      x = values.DistributedValues(device_map, (a, b))
+      y = values.DistributedValues(device_map, (a, a))
       # Removed device and input tensor dtype details from the error message
       # since the order of the device and the corresponding input tensor dtype
       # is not deterministic over different runs.
diff --git a/tensorflow/python/module/module_test.py b/tensorflow/python/module/module_test.py
index 62b777b8188..71594de1058 100644
--- a/tensorflow/python/module/module_test.py
+++ b/tensorflow/python/module/module_test.py
@@ -247,10 +247,13 @@ class VariableTrackingTest(test_util.TensorFlowTestCase):
     self.assertEqual(len(m.child.child.trainable_variables), 0)
 
   def test_supports_distributed_variables(self):
+    device_map = distributed_values.SingleDeviceMap("/CPU:0")
     mirrored = distributed_values.MirroredVariable(
-        None, [variables.Variable(1.)], variables.VariableAggregation.SUM)
+        None, device_map, [variables.Variable(1.)],
+        variables.VariableAggregation.SUM)
     tpu = distributed_values.TPUMirroredVariable(
         strategy=None,
+        device_map=device_map,
         values=[variables.Variable(42.)],
         aggregation=None)
     aggregating = distributed_values.AggregatingVariable(
diff --git a/tensorflow/python/ops/stateful_random_ops_test.py b/tensorflow/python/ops/stateful_random_ops_test.py
index b68753617d6..499698b7d57 100644
--- a/tensorflow/python/ops/stateful_random_ops_test.py
+++ b/tensorflow/python/ops/stateful_random_ops_test.py
@@ -727,7 +727,9 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
     devices = ["cpu:0", "cpu:1"]
     strat = MirroredStrategy(devices=devices)
     # Use `PerReplica` to specify which `gen` is sent to which replica
-    gens = dist_values.PerReplica([[g] for g in gens])
+    gens = dist_values.PerReplica(
+        device_map=dist_values.ReplicaDeviceMap(devices),
+        values=[[g] for g in gens])
     with strat.scope():
       def f(gen):
         t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
diff --git a/tensorflow/python/tpu/tpu.py b/tensorflow/python/tpu/tpu.py
index f6835fce76d..2af31a9dd58 100644
--- a/tensorflow/python/tpu/tpu.py
+++ b/tensorflow/python/tpu/tpu.py
@@ -259,7 +259,11 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
     self._pivot = pivot
     self._replicated_vars = {}
 
-  def get_replicated_var_handle(self, name, vars_, is_mirrored=False):
+  def get_replicated_var_handle(self,
+                                name,
+                                vars_,
+                                device_map=None,
+                                is_mirrored=False):
     """Returns a variable handle for replicated TPU variable 'var'.
 
     This is a method used by an experimental replicated variable implementation
@@ -268,6 +272,8 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
     Args:
       name: The common name of the variable.
       vars_: The replicated TPU variables.
+      device_map: The DeviceMap used to create the variables if it is a
+        TPUMirroredVariable.
       is_mirrored: Whether the variables are mirrored, which guarantees the
         values in each replica are always the same.
 
@@ -281,20 +287,15 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
     if handle is not None:
       return handle
 
-    if device_assignment is not None:
-      job_name = pydev.DeviceSpec.from_string(vars_[0].device).job
-
-      tpu_devices = set()
+    replicated_vars = []
+    if device_assignment is not None and device_map is not None:
+      job_name = pydev.DeviceSpec.from_string(device_map.all_devices[0]).job
       for replica_id in range(device_assignment.num_replicas):
-        for logical_core in range(device_assignment.num_cores_per_replica):
-          tpu_devices.add(
-              device_util.canonicalize(
-                  device_assignment.tpu_device(
-                      replica=replica_id,
-                      logical_core=logical_core,
-                      job=job_name)))
-
-      replicated_vars = [v for v in vars_ if v.device in tpu_devices]
+        tpu_device = device_assignment.tpu_device(
+            replica=replica_id, logical_core=0, job=job_name)
+        tpu_device = device_util.canonicalize(tpu_device)
+        replica = device_map.replica_for_device(tpu_device)
+        replicated_vars.append(vars_[replica])
     else:
       replicated_vars = vars_
 
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index f1a31d01dd4..3af76873051 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -768,7 +768,7 @@ class Optimizer(
       # pylint: enable=protected-access
       mirrored_slot = named_slots.get(key, None)
       if mirrored_slot is None: return None
-      return mirrored_slot._get_closest()  # pylint: disable=protected-access
+      return mirrored_slot.get(device=var.device)
 
     return named_slots.get(_var_key(var), None)
 

From baee6674096e911e48c0218edc7cf93b3d34f015 Mon Sep 17 00:00:00 2001
From: Stephan Herhut 
Date: Mon, 2 Dec 2019 04:48:13 -0800
Subject: [PATCH 136/279] Fix computation of strides value in memref
 descriptor.

The strides in the memref descriptor encode the multipliers for
indexing and not subsampling strides, as the old code assumed.
Now, the strides are intialized properly.

PiperOrigin-RevId: 283317479
Change-Id: I1679a2569eb56c683c1a5a6f9d59ada7b02a3552
---
 .../xla/service/mlir_gpu/mlir_compiler.cc       | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc
index ddfa28a9b42..b035a8ddcb5 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc
@@ -344,15 +344,26 @@ Status InsertBufferLoadPreduleIntoKernel(
             loc, entry_type, builder.getI64IntegerAttr(extent.value()));
         builder.create(loc, extentValue, shapeEntryPtr);
       }
-      // Finally, fill the strides with all ones.
+      // Finally, fill the strides.
+      // TODO(b/137624192): Take assigned layout into account.
       entry_type = struct_type.getStructElementType(4).getArrayElementType();
-      for (int64 idx = 0; idx < shape.rank(); ++idx) {
+      Value* accumulator = nullptr;
+      for (int64 idx = shape.rank() - 1; idx >= 0; --idx) {
         auto indexValue = builder.create(
             loc, offset_type, builder.getI64IntegerAttr(idx));
         auto strideEntryPtr = builder.create(
             loc, entry_type, descPtr,
             llvm::ArrayRef{zero, strideIndex, indexValue});
-        builder.create(loc, one, strideEntryPtr);
+        if (accumulator) {
+          auto strideValue = builder.create(
+              loc, entry_type,
+              builder.getI64IntegerAttr(shape.dimensions(idx + 1)));
+          accumulator = builder.create(
+              loc, entry_type, accumulator, strideValue);
+        } else {
+          accumulator = one;
+        }
+        builder.create(loc, accumulator, strideEntryPtr);
       }
       // Now we can use the descriptor instead of the original argument.
       value->replaceAllUsesWith(descPtr);

From e99f359a5317ccb52ef5999e4eb6098ccb3909c0 Mon Sep 17 00:00:00 2001
From: Chris Jones 
Date: Mon, 2 Dec 2019 05:49:33 -0800
Subject: [PATCH 137/279] Add `Hash` method to `XlaComputation` in the Python
 bindings.

PiperOrigin-RevId: 283323918
Change-Id: Id017914f43680365c0584f3d46dfef200ce77620
---
 tensorflow/compiler/xla/python/xla.cc          | 14 +++++++++++++-
 tensorflow/compiler/xla/python/xla_client.py   |  3 +++
 .../compiler/xla/python/xla_client_test.py     | 18 ++++++++++++++++++
 3 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc
index 4398561c5c4..054c1da9e03 100644
--- a/tensorflow/compiler/xla/python/xla.cc
+++ b/tensorflow/compiler/xla/python/xla.cc
@@ -110,6 +110,17 @@ StatusOr GetComputationHloDotGraph(
                      RenderedGraphFormat::kDot);
 }
 
+// Hashes the HLO module.
+StatusOr HashComputation(const XlaComputation& computation) {
+  TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config,
+                      HloModule::CreateModuleConfigFromProto(
+                          computation.proto(), GetDebugOptionsFromFlags()));
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr hlo_module,
+      HloModule::CreateFromProto(computation.proto(), module_config));
+  return hlo_module->Hash();
+}
+
 // Registers a 'fn_capsule' as a CPU custom call target.
 // 'fn_capsule' must be a void* pointer encapsulated in a PyCapsule object,
 // with name "xla._CUSTOM_CALL_TARGET".
@@ -577,7 +588,8 @@ PYBIND11_MODULE(xla_extension, m) {
       .def("GetProgramShape", &XlaComputation::GetProgramShape)
       .def("GetSerializedProto", &GetComputationSerializedProto)
       .def("GetHloText", &GetComputationHloText)
-      .def("GetHloDotGraph", &GetComputationHloDotGraph);
+      .def("GetHloDotGraph", &GetComputationHloDotGraph)
+      .def("Hash", &HashComputation);
 
   py::class_(m, "XlaOp");
 
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index 65db35a6988..c8f66f704d7 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -594,6 +594,9 @@ class Computation(object):
   def GetReturnValueShape(self):
     return self._c_computation.GetProgramShape().result_shape()
 
+  def Hash(self):
+    return self._c_computation.Hash()
+
 
 # An Executable is a C++ class that duck types with the following API:
 # class Executable(object):
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py
index db1c01293ab..f490a05e25d 100644
--- a/tensorflow/compiler/xla/python/xla_client_test.py
+++ b/tensorflow/compiler/xla/python/xla_client_test.py
@@ -112,6 +112,24 @@ class ComputationPrinting(absltest.TestCase):
     self.assertTrue(hlo_dot_graph.startswith("digraph "))
 
 
+class ComputationHashTest(absltest.TestCase):
+
+  def testHash(self):
+    builder0 = xla_client.ComputationBuilder("computation0")
+    p0 = builder0.ParameterFromNumpy(np.float32(0))
+    p1 = builder0.ParameterFromNumpy(np.zeros((4,), np.float32))
+    builder0.Mul(p0, p1)
+    computation0 = builder0.Build()
+
+    builder1 = xla_client.ComputationBuilder("computation1")
+    p0 = builder1.ParameterFromNumpy(np.float32(0))
+    p1 = builder1.ParameterFromNumpy(np.zeros((4,), np.float32))
+    builder1.Mul(p0, p1)
+    computation1 = builder1.Build()
+
+    self.assertEqual(computation0.Hash(), computation1.Hash())
+
+
 class ComputationsWithConstantsTest(ComputationTest):
   """Tests focusing on Constant ops."""
 

From 76a2232f78434361be41e9a41965ef1d3cf42e19 Mon Sep 17 00:00:00 2001
From: Alexander Belyaev 
Date: Mon, 2 Dec 2019 06:30:19 -0800
Subject: [PATCH 138/279] Lower linalg.indexed_generic with libcall to LLVM.

PiperOrigin-RevId: 283328994
Change-Id: Icc5e580fac279f76d55d715d3d4a83d97f481670
---
 .../Conversion/LinalgToLLVM/LinalgToLLVM.cpp  | 58 ++++++++++++++++++-
 1 file changed, 57 insertions(+), 1 deletion(-)

diff --git a/third_party/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/third_party/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index ebb0fd75753..ff516d7ef29 100644
--- a/third_party/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/third_party/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -340,6 +340,28 @@ public:
   }
 };
 
+template 
+static SmallVector ExtractOperandTypes(Operation *op) {
+  return SmallVector{op->getOperandTypes()};
+}
+
+template <>
+SmallVector ExtractOperandTypes(Operation *op) {
+  auto ctx = op->getContext();
+  auto indexedGenericOp = cast(op);
+  auto numLoops = indexedGenericOp.getNumLoops();
+
+  SmallVector result;
+  result.reserve(numLoops + op->getNumOperands());
+  for (unsigned i = 0; i < numLoops; ++i) {
+    result.push_back(IndexType::get(ctx));
+  }
+  for (auto type : op->getOperandTypes()) {
+    result.push_back(type);
+  }
+  return result;
+}
+
 // Get a SymbolRefAttr containing the library function name for the LinalgOp.
 // If the library function does not exist, insert a declaration.
 template 
@@ -359,7 +381,7 @@ static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
     return fnNameAttr;
   }
 
-  SmallVector inputTypes(op->getOperandTypes());
+  SmallVector inputTypes(ExtractOperandTypes(op));
   assert(op->getNumResults() == 0 &&
          "Library call for linalg operation can be generated only for ops that "
          "have void return types");
@@ -430,6 +452,40 @@ public:
   }
 };
 
+/// Conversion pattern specialization for IndexedGenericOp.
+template <>
+class LinalgOpConversion
+    : public OpRewritePattern {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(IndexedGenericOp op,
+                                     PatternRewriter &rewriter) const override {
+    auto libraryCallName =
+        getLibraryCallSymbolRef(op, rewriter);
+    if (!libraryCallName)
+      return this->matchFailure();
+
+    // TODO(pifon, ntv): Use induction variables values instead of zeros, when
+    // IndexedGenericOp is tiled.
+    auto zero = rewriter.create(
+        op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
+    auto indexedGenericOp = cast(op);
+    auto numLoops = indexedGenericOp.getNumLoops();
+    SmallVector operands;
+    operands.reserve(numLoops + op.getNumOperands());
+    for (unsigned i = 0; i < numLoops; ++i) {
+      operands.push_back(zero);
+    }
+    for (auto operand : op.getOperands()) {
+      operands.push_back(operand);
+    }
+    rewriter.replaceOpWithNewOp(op, libraryCallName.getValue(),
+                                              ArrayRef{}, operands);
+    return this->matchSuccess();
+  }
+};
+
 /// A non-conversion rewrite pattern kicks in to convert CopyOp with
 /// permutations into a sequence of TransposeOp and permutation-free CopyOp.
 /// This interplays together with TransposeOpConversion and

From f01d866b4a02c60a207eb7c0170c0f57801d4b85 Mon Sep 17 00:00:00 2001
From: JKIsaacLee <51275047+JKIsaacLee@users.noreply.github.com>
Date: Mon, 2 Dec 2019 07:08:31 -0800
Subject: [PATCH 139/279] add missing '>' in Ch-2

add missing '>' in Ch-2
(tensor<2x3xf64)->(tensor<2x3xf64>)

Closes #283

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/283 from JKIsaacLee:patch-1 b69fe8d51e2a540f7efaded159d35b88778ad159
PiperOrigin-RevId: 283333807
Change-Id: Id4657d664066fde23a93a042d9fce5d4af177c17
---
 third_party/mlir/g3doc/Tutorials/Toy/Ch-2.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/third_party/mlir/g3doc/Tutorials/Toy/Ch-2.md b/third_party/mlir/g3doc/Tutorials/Toy/Ch-2.md
index 056b25779cb..3e7c680bcb8 100755
--- a/third_party/mlir/g3doc/Tutorials/Toy/Ch-2.md
+++ b/third_party/mlir/g3doc/Tutorials/Toy/Ch-2.md
@@ -72,7 +72,7 @@ Let's break down the anatomy of this MLIR operation:
         are always constant. Here we define a boolean attribute named 'inplace'
         that has a constant value of true.
 
--   `(tensor<2x3xf64) -> tensor<3x2xf64>`
+-   `(tensor<2x3xf64>) -> tensor<3x2xf64>`
 
     *   This refers to the type of the operation in a functional form, spelling
         the types of the arguments in parentheses and the type of the return

From b4deb584386f48d04ce2dac38643ca3be5466284 Mon Sep 17 00:00:00 2001
From: Lei Zhang 
Date: Mon, 2 Dec 2019 07:51:27 -0800
Subject: [PATCH 140/279] NFC: Update std.subview op to use
 AttrSizedOperandSegments

This turns a few manually written helper methods into auto-generated ones.

PiperOrigin-RevId: 283339617
Change-Id: I6d42476ad15eef850aa0398ae09465099bfda67c
---
 .../include/mlir/Dialect/StandardOps/Ops.td   |  50 +++----
 third_party/mlir/include/mlir/IR/Builders.h   |   2 +
 .../StandardToLLVM/ConvertStandardToLLVM.cpp  |   3 +-
 .../mlir/lib/Dialect/StandardOps/Ops.cpp      | 123 +++++++-----------
 third_party/mlir/lib/IR/Builders.cpp          |   8 ++
 5 files changed, 77 insertions(+), 109 deletions(-)

diff --git a/third_party/mlir/include/mlir/Dialect/StandardOps/Ops.td b/third_party/mlir/include/mlir/Dialect/StandardOps/Ops.td
index e2731acf47f..70cf3bb7775 100644
--- a/third_party/mlir/include/mlir/Dialect/StandardOps/Ops.td
+++ b/third_party/mlir/include/mlir/Dialect/StandardOps/Ops.td
@@ -1248,7 +1248,7 @@ def ViewOp : Std_Op<"view", [NoSideEffect]> {
   let hasCanonicalizer = 1;
 }
 
-def SubViewOp : Std_Op<"subview", [NoSideEffect]> {
+def SubViewOp : Std_Op<"subview", [AttrSizedOperandSegments, NoSideEffect]> {
   let summary = "memref subview operation";
   let description = [{
     The "subview" operation converts a memref type to another memref type
@@ -1356,23 +1356,25 @@ def SubViewOp : Std_Op<"subview", [NoSideEffect]> {
 
   // TODO(b/144779634, ravishankarm) : Use different arguments for
   // offsets, sizes and strides.
-  let arguments = (ins AnyMemRef:$source, I32Attr:$num_offsets,
-                   I32Attr:$num_sizes, I32Attr:$num_strides,
-                   Variadic:$operands);
+  let arguments = (ins
+    AnyMemRef:$source,
+    Variadic:$offsets,
+    Variadic:$sizes,
+    Variadic:$strides,
+    I32ElementsAttr:$operand_segment_sizes
+  );
   let results = (outs AnyMemRef);
 
-  let builders = [OpBuilder<
-    "Builder *b, OperationState &result, Value *source, "
-    "ArrayRef offsets, ArrayRef sizes, "
-    "ArrayRef strides, Type resultType = Type(), "
-    "ArrayRef attrs = {}">,
+  let builders = [
     OpBuilder<
-    "Builder *builder, OperationState &result, Type resultType, Value *source">,
+      "Builder *b, OperationState &result, Value *source, "
+      "ArrayRef offsets, ArrayRef sizes, "
+      "ArrayRef strides, Type resultType = Type(), "
+      "ArrayRef attrs = {}">,
     OpBuilder<
-    "Builder *builder, OperationState &result, Type resultType, Value *source, "
-    "unsigned num_offsets, unsigned num_sizes, unsigned num_strides, "
-    "ArrayRef offsets, ArrayRef sizes, "
-    "ArrayRef strides">];
+      "Builder *builder, OperationState &result, "
+      "Type resultType, Value *source">
+  ];
 
   let extraClassDeclaration = [{
     /// Returns the type of the base memref operand.
@@ -1384,28 +1386,16 @@ def SubViewOp : Std_Op<"subview", [NoSideEffect]> {
     MemRefType getType() { return getResult()->getType().cast(); }
 
     /// Returns as integer value the number of offset operands.
-    int64_t getNumOffsets() {
-      return num_offsets().getSExtValue();
-    }
+    int64_t getNumOffsets() { return llvm::size(offsets()); }
 
     /// Returns as integer value the number of size operands.
-    int64_t getNumSizes() {
-      return num_sizes().getSExtValue();
-    }
+    int64_t getNumSizes() { return llvm::size(sizes()); }
 
     /// Returns as integer value the number of stride operands.
-    int64_t getNumStrides() {
-      return num_strides().getSExtValue();
-    }
-
-    /// Returns the dynamic offsets for this subview operation.
-    operand_range getDynamicOffsets();
+    int64_t getNumStrides() { return llvm::size(strides()); }
 
     /// Returns the dynamic sizes for this subview operation if specified.
-    operand_range getDynamicSizes();
-
-    /// Returns the dynamic strides for this subview operation if specified.
-    operand_range getDynamicStrides();
+    operand_range getDynamicSizes() { return sizes(); }
 
     // Auxiliary range data structure and helper function that unpacks the
     // offset, size and stride operands of the SubViewOp into a list of triples.
diff --git a/third_party/mlir/include/mlir/IR/Builders.h b/third_party/mlir/include/mlir/IR/Builders.h
index 01ad38cfc11..c5ed7b16b56 100644
--- a/third_party/mlir/include/mlir/IR/Builders.h
+++ b/third_party/mlir/include/mlir/IR/Builders.h
@@ -120,6 +120,8 @@ public:
   IntegerAttr getI32IntegerAttr(int32_t value);
   IntegerAttr getI64IntegerAttr(int64_t value);
 
+  DenseIntElementsAttr getI32VectorAttr(ArrayRef values);
+
   ArrayAttr getAffineMapArrayAttr(ArrayRef values);
   ArrayAttr getI32ArrayAttr(ArrayRef values);
   ArrayAttr getI64ArrayAttr(ArrayRef values);
diff --git a/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index ae2b7837c40..d226766a3fc 100644
--- a/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -1476,7 +1476,6 @@ struct SubViewOpLowering : public LLVMLegalizationPattern {
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op->getLoc();
     auto viewOp = cast(op);
-    SubViewOpOperandAdaptor adaptor(operands);
     // TODO(b/144779634, ravishankarm) : After Tblgen is adapted to support
     // having multiple variadic operands where each operand can have different
     // number of entries, clean all of this up.
@@ -1518,7 +1517,7 @@ struct SubViewOpLowering : public LLVMLegalizationPattern {
       return matchFailure();
 
     // Create the descriptor.
-    MemRefDescriptor sourceMemRef(adaptor.source());
+    MemRefDescriptor sourceMemRef(operands.front());
     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
 
     // Copy the buffer pointer from the old descriptor to the new one.
diff --git a/third_party/mlir/lib/Dialect/StandardOps/Ops.cpp b/third_party/mlir/lib/Dialect/StandardOps/Ops.cpp
index 0bf562337a9..31431be5054 100644
--- a/third_party/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/third_party/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -1370,7 +1370,7 @@ OpFoldResult DimOp::fold(ArrayRef operands) {
   // Fold dim to the size argument of a SubViewOp.
   auto memref = memrefOrTensor()->getDefiningOp();
   if (auto subview = dyn_cast_or_null(memref)) {
-    auto sizes = subview.getDynamicSizes();
+    auto sizes = subview.sizes();
     if (!sizes.empty())
       return *(sizes.begin() + getIndex());
   }
@@ -2563,35 +2563,23 @@ static Type inferSubViewResultType(MemRefType memRefType) {
                          memRefType.getMemorySpace());
 }
 
-void mlir::SubViewOp::build(Builder *b, OperationState &result, Type resultType,
-                            Value *source, unsigned num_offsets,
-                            unsigned num_sizes, unsigned num_strides,
-                            ArrayRef offsets, ArrayRef sizes,
-                            ArrayRef strides) {
-  SmallVector operands;
-  operands.reserve(num_offsets + num_sizes + num_strides);
-  operands.append(offsets.begin(), offsets.end());
-  operands.append(sizes.begin(), sizes.end());
-  operands.append(strides.begin(), strides.end());
-  build(b, result, resultType, source, b->getI32IntegerAttr(num_offsets),
-        b->getI32IntegerAttr(num_sizes), b->getI32IntegerAttr(num_strides),
-        operands);
-}
-
 void mlir::SubViewOp::build(Builder *b, OperationState &result, Value *source,
                             ArrayRef offsets, ArrayRef sizes,
                             ArrayRef strides, Type resultType,
                             ArrayRef attrs) {
   if (!resultType)
     resultType = inferSubViewResultType(source->getType().cast());
-  build(b, result, resultType, source, offsets.size(), sizes.size(),
-        strides.size(), offsets, sizes, strides);
+  auto segmentAttr = b->getI32VectorAttr(
+      {1, static_cast(offsets.size()), static_cast(sizes.size()),
+       static_cast(strides.size())});
+  build(b, result, resultType, source, offsets, sizes, strides, segmentAttr);
   result.addAttributes(attrs);
 }
 
 void mlir::SubViewOp::build(Builder *b, OperationState &result, Type resultType,
                             Value *source) {
-  build(b, result, resultType, source, 0, 0, 0, {}, {}, {});
+  build(b, result, source, /*offsets=*/{}, /*sizes=*/{}, /*strides=*/{},
+        resultType);
 }
 
 static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
@@ -2607,12 +2595,13 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
       parser.parseOperandList(stridesInfo, OpAsmParser::Delimiter::Square)) {
     return failure();
   }
+
   auto builder = parser.getBuilder();
-  result.addAttribute("num_offsets",
-                      builder.getI32IntegerAttr(offsetsInfo.size()));
-  result.addAttribute("num_sizes", builder.getI32IntegerAttr(sizesInfo.size()));
-  result.addAttribute("num_strides",
-                      builder.getI32IntegerAttr(stridesInfo.size()));
+  result.addAttribute(
+      SubViewOp::getOperandSegmentSizeAttr(),
+      builder.getI32VectorAttr({1, static_cast(offsetsInfo.size()),
+                                static_cast(sizesInfo.size()),
+                                static_cast(stridesInfo.size())}));
 
   return failure(
       parser.parseOptionalAttrDict(result.attributes) ||
@@ -2627,14 +2616,15 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
 
 static void print(OpAsmPrinter &p, SubViewOp op) {
   p << op.getOperationName() << ' ' << *op.getOperand(0) << '[';
-  p.printOperands(op.getDynamicOffsets());
+  p.printOperands(op.offsets());
   p << "][";
-  p.printOperands(op.getDynamicSizes());
+  p.printOperands(op.sizes());
   p << "][";
-  p.printOperands(op.getDynamicStrides());
+  p.printOperands(op.strides());
   p << ']';
-  SmallVector elidedAttrs = {"num_offsets", "num_sizes",
-                                           "num_strides"};
+
+  SmallVector elidedAttrs = {
+      SubViewOp::getOperandSegmentSizeAttr()};
   p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
   p << " : " << op.getOperand(0)->getType() << " to " << op.getType();
 }
@@ -2689,14 +2679,16 @@ static LogicalResult verify(SubViewOp op) {
   }
 
   // Verify that if the shape of the subview type is static, then sizes are not
-  // dynamic values, and viceversa.
+  // dynamic values, and vice versa.
   if ((subViewType.hasStaticShape() && op.getNumSizes() != 0) ||
       (op.getNumSizes() == 0 && !subViewType.hasStaticShape())) {
     return op.emitError("invalid to specify dynamic sizes when subview result "
                         "type is statically shaped and viceversa");
   }
+
+  // Verify that if dynamic sizes are specified, then the result memref type
+  // have full dynamic dimensions.
   if (op.getNumSizes() > 0) {
-    // Verify that non if the shape values of the result type are static.
     if (llvm::any_of(subViewType.getShape(), [](int64_t dim) {
           return dim != ShapedType::kDynamicSize;
         })) {
@@ -2758,9 +2750,8 @@ SmallVector SubViewOp::getRanges() {
   unsigned rank = getType().getRank();
   res.reserve(rank);
   for (unsigned i = 0; i < rank; ++i)
-    res.emplace_back(Range{*(getDynamicOffsets().begin() + i),
-                           *(getDynamicSizes().begin() + i),
-                           *(getDynamicStrides().begin() + i)});
+    res.emplace_back(Range{*(offsets().begin() + i), *(sizes().begin() + i),
+                           *(strides().begin() + i)});
   return res;
 }
 
@@ -2792,13 +2783,13 @@ public:
     // Follow all or nothing approach for shapes for now. If all the operands
     // for sizes are constants then fold it into the type of the result memref.
     if (subViewType.hasStaticShape() ||
-        llvm::any_of(subViewOp.getDynamicSizes(), [](Value *operand) {
+        llvm::any_of(subViewOp.sizes(), [](Value *operand) {
           return !matchPattern(operand, m_ConstantIndex());
         })) {
       return matchFailure();
     }
     SmallVector staticShape(subViewOp.getNumSizes());
-    for (auto size : enumerate(subViewOp.getDynamicSizes())) {
+    for (auto size : enumerate(subViewOp.sizes())) {
       auto defOp = size.value()->getDefiningOp();
       assert(defOp);
       staticShape[size.index()] = cast(defOp).getValue();
@@ -2808,12 +2799,12 @@ public:
         subViewType.getMemorySpace());
     auto newSubViewOp = rewriter.create(
         subViewOp.getLoc(), subViewOp.source(),
-        llvm::to_vector<4>(subViewOp.getDynamicOffsets()), ArrayRef(),
-        llvm::to_vector<4>(subViewOp.getDynamicStrides()), newMemRefType);
+        llvm::to_vector<4>(subViewOp.offsets()), ArrayRef(),
+        llvm::to_vector<4>(subViewOp.strides()), newMemRefType);
     // Insert a memref_cast for compatibility of the uses of the op.
     rewriter.replaceOpWithNewOp(
-        llvm::to_vector<4>(subViewOp.getDynamicSizes()), subViewOp,
-        newSubViewOp, subViewOp.getType());
+        llvm::to_vector<4>(subViewOp.sizes()), subViewOp, newSubViewOp,
+        subViewOp.getType());
     return matchSuccess();
   }
 };
@@ -2839,14 +2830,14 @@ public:
         failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) ||
         llvm::is_contained(baseStrides,
                            MemRefType::getDynamicStrideOrOffset()) ||
-        llvm::any_of(subViewOp.getDynamicStrides(), [](Value *stride) {
+        llvm::any_of(subViewOp.strides(), [](Value *stride) {
           return !matchPattern(stride, m_ConstantIndex());
         })) {
       return matchFailure();
     }
 
     SmallVector staticStrides(subViewOp.getNumStrides());
-    for (auto stride : enumerate(subViewOp.getDynamicStrides())) {
+    for (auto stride : enumerate(subViewOp.strides())) {
       auto defOp = stride.value()->getDefiningOp();
       assert(defOp);
       assert(baseStrides[stride.index()] > 0);
@@ -2858,15 +2849,15 @@ public:
     MemRefType newMemRefType =
         MemRefType::get(subViewType.getShape(), subViewType.getElementType(),
                         layoutMap, subViewType.getMemorySpace());
-    auto newSubViewOp = rewriter.create(
-        subViewOp.getLoc(), subViewOp.source(),
-        llvm::to_vector<4>(subViewOp.getDynamicOffsets()),
-        llvm::to_vector<4>(subViewOp.getDynamicSizes()), ArrayRef(),
-        newMemRefType);
+    auto newSubViewOp =
+        rewriter.create(subViewOp.getLoc(), subViewOp.source(),
+                                   llvm::to_vector<4>(subViewOp.offsets()),
+                                   llvm::to_vector<4>(subViewOp.sizes()),
+                                   ArrayRef(), newMemRefType);
     // Insert a memref_cast for compatibility of the uses of the op.
     rewriter.replaceOpWithNewOp(
-        llvm::to_vector<4>(subViewOp.getDynamicStrides()), subViewOp,
-        newSubViewOp, subViewOp.getType());
+        llvm::to_vector<4>(subViewOp.strides()), subViewOp, newSubViewOp,
+        subViewOp.getType());
     return matchSuccess();
   }
 };
@@ -2893,14 +2884,14 @@ public:
         llvm::is_contained(baseStrides,
                            MemRefType::getDynamicStrideOrOffset()) ||
         baseOffset == MemRefType::getDynamicStrideOrOffset() ||
-        llvm::any_of(subViewOp.getDynamicOffsets(), [](Value *stride) {
+        llvm::any_of(subViewOp.offsets(), [](Value *stride) {
           return !matchPattern(stride, m_ConstantIndex());
         })) {
       return matchFailure();
     }
 
     auto staticOffset = baseOffset;
-    for (auto offset : enumerate(subViewOp.getDynamicOffsets())) {
+    for (auto offset : enumerate(subViewOp.offsets())) {
       auto defOp = offset.value()->getDefiningOp();
       assert(defOp);
       assert(baseStrides[offset.index()] > 0);
@@ -2915,39 +2906,17 @@ public:
                         layoutMap, subViewType.getMemorySpace());
     auto newSubViewOp = rewriter.create(
         subViewOp.getLoc(), subViewOp.source(), ArrayRef(),
-        llvm::to_vector<4>(subViewOp.getDynamicSizes()),
-        llvm::to_vector<4>(subViewOp.getDynamicStrides()), newMemRefType);
+        llvm::to_vector<4>(subViewOp.sizes()),
+        llvm::to_vector<4>(subViewOp.strides()), newMemRefType);
     // Insert a memref_cast for compatibility of the uses of the op.
     rewriter.replaceOpWithNewOp(
-        llvm::to_vector<4>(subViewOp.getDynamicOffsets()), subViewOp,
-        newSubViewOp, subViewOp.getType());
+        llvm::to_vector<4>(subViewOp.offsets()), subViewOp, newSubViewOp,
+        subViewOp.getType());
     return matchSuccess();
   }
 };
 
 } // end anonymous namespace
-SubViewOp::operand_range SubViewOp::getDynamicOffsets() {
-  auto numOffsets = getNumOffsets();
-  assert(getNumOperands() >= numOffsets + 1);
-  return {operand_begin() + 1, operand_begin() + 1 + numOffsets};
-}
-
-SubViewOp::operand_range SubViewOp::getDynamicSizes() {
-  auto numSizes = getNumSizes();
-  auto numOffsets = getNumOffsets();
-  assert(getNumOperands() >= numSizes + numOffsets + 1);
-  return {operand_begin() + 1 + numOffsets,
-          operand_begin() + 1 + numOffsets + numSizes};
-}
-
-SubViewOp::operand_range SubViewOp::getDynamicStrides() {
-  auto numSizes = getNumSizes();
-  auto numOffsets = getNumOffsets();
-  auto numStrides = getNumStrides();
-  assert(getNumOperands() >= numSizes + numOffsets + numStrides + 1);
-  return {operand_begin() + (1 + numOffsets + numSizes),
-          operand_begin() + (1 + numOffsets + numSizes + numStrides)};
-}
 
 void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
diff --git a/third_party/mlir/lib/IR/Builders.cpp b/third_party/mlir/lib/IR/Builders.cpp
index afdeefd023c..4d6cd3550ca 100644
--- a/third_party/mlir/lib/IR/Builders.cpp
+++ b/third_party/mlir/lib/IR/Builders.cpp
@@ -100,6 +100,14 @@ IntegerAttr Builder::getI64IntegerAttr(int64_t value) {
   return IntegerAttr::get(getIntegerType(64), APInt(64, value));
 }
 
+DenseIntElementsAttr Builder::getI32VectorAttr(ArrayRef values) {
+  return DenseElementsAttr::get(
+             VectorType::get(static_cast(values.size()),
+                             getIntegerType(32)),
+             values)
+      .cast();
+}
+
 IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
   return IntegerAttr::get(getIntegerType(32), APInt(32, value));
 }

From 218327839a153f92d2f5e63fab6384eb64a88e4f Mon Sep 17 00:00:00 2001
From: Lei Zhang 
Date: Mon, 2 Dec 2019 07:54:23 -0800
Subject: [PATCH 141/279] [DRR] Introduce `$_` to ignore op argument match

Right now op argument matching in DRR is position-based, meaning we need to
specify N arguments for an op with N ODS-declared argument. This can be annoying
when we don't want to capture all the arguments. `$_` is to remedy the situation.

PiperOrigin-RevId: 283339992
Change-Id: I629efe95319d8b7d55edeed77f8492381f8c910e
---
 third_party/mlir/g3doc/DeclarativeRewrites.md      |  7 +++++++
 third_party/mlir/lib/TableGen/Pattern.cpp          |  3 ++-
 third_party/mlir/test/lib/TestDialect/TestOps.td   | 12 ++++++++++++
 third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp |  7 +++++--
 4 files changed, 26 insertions(+), 3 deletions(-)

diff --git a/third_party/mlir/g3doc/DeclarativeRewrites.md b/third_party/mlir/g3doc/DeclarativeRewrites.md
index 2d9fb5b5219..5117ebd5c84 100644
--- a/third_party/mlir/g3doc/DeclarativeRewrites.md
+++ b/third_party/mlir/g3doc/DeclarativeRewrites.md
@@ -144,6 +144,13 @@ Also note that we only need to add `TypeConstraint` or `AttributeConstraint`
 when we need to further limit the match criteria. If all valid cases to the op
 are acceptable, then we can leave the constraint unspecified.
 
+`$_` is a special symbol to mean ignore capturing an argument. For example,
+`def : Pat<(AOp $_, $b), ...>` means only `$b` is interesting to capture and
+will be referenced later in result patterns. It's still possible to place
+additional constraints even if the symbol is not to be captured; for such case,
+you can simply use just the `TypeConstraint` or `AttributeConstraint` without a
+bound symbol, for example, `def : Pat<(AOp $a, F32Attr), ...>`.
+
 #### Matching DAG of operations
 
 To match an DAG of ops, use nested `dag` objects:
diff --git a/third_party/mlir/lib/TableGen/Pattern.cpp b/third_party/mlir/lib/TableGen/Pattern.cpp
index d3c1dddd21e..098dba3ae6e 100644
--- a/third_party/mlir/lib/TableGen/Pattern.cpp
+++ b/third_party/mlir/lib/TableGen/Pattern.cpp
@@ -564,7 +564,8 @@ void tblgen::Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
       // We can only bind symbols to op arguments in source pattern. Those
       // symbols are referenced in result patterns.
       auto treeArgName = tree.getArgName(i);
-      if (!treeArgName.empty()) {
+      // `$_` is a special symbol meaning ignore the current argument.
+      if (!treeArgName.empty() && treeArgName != "_") {
         LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
                                 << treeArgName << '\n');
         if (!infoMap.bindOpArgument(treeArgName, op, i)) {
diff --git a/third_party/mlir/test/lib/TestDialect/TestOps.td b/third_party/mlir/test/lib/TestDialect/TestOps.td
index 6bb0cbc7f0c..6952eaa7717 100644
--- a/third_party/mlir/test/lib/TestDialect/TestOps.td
+++ b/third_party/mlir/test/lib/TestDialect/TestOps.td
@@ -479,6 +479,18 @@ def OpJ : TEST_Op<"op_j">, Arguments<(ins)>, Results<(outs I32)>;
 def OpK : TEST_Op<"op_k">, Arguments<(ins)>, Results<(outs I32)>;
 def : Pat<(OpJ), (OpK)>;
 
+// Test `$_` for ignoring op argument match.
+def TestIgnoreArgMatchSrcOp : TEST_Op<"ignore_arg_match_src"> {
+  let arguments = (ins
+    AnyType:$a, AnyType:$b, AnyType:$c,
+    AnyAttr:$d, AnyAttr:$e, AnyAttr:$f);
+}
+def TestIgnoreArgMatchDstOp : TEST_Op<"ignore_arg_match_dst"> {
+  let arguments = (ins AnyType:$b, AnyAttr:$f);
+}
+def : Pat<(TestIgnoreArgMatchSrcOp $_, $b, I32, I64Attr:$_, $_, $f),
+          (TestIgnoreArgMatchDstOp $b, $f)>;
+
 def OpInterleavedOperandAttribute1 : TEST_Op<"interleaved_operand_attr1"> {
   let arguments = (ins
     I32:$input1,
diff --git a/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp b/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp
index d2776e05805..a2cace7fb60 100644
--- a/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -315,7 +315,8 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth,
 
   // Capture the value
   auto name = tree.getArgName(argIndex);
-  if (!name.empty()) {
+  // `$_` is a special symbol to ignore op argument matching.
+  if (!name.empty() && name != "_") {
     // We need to subtract the number of attributes before this operand to get
     // the index in the operand list.
     auto numPrevAttrs = std::count_if(
@@ -329,6 +330,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth,
 
 void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
                                         int indent) {
+
   Operator &op = tree.getDialectOp(opMap);
   auto *namedAttr = op.getArg(argIndex).get();
   const auto &attr = namedAttr->attr;
@@ -371,7 +373,8 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
 
   // Capture the value
   auto name = tree.getArgName(argIndex);
-  if (!name.empty()) {
+  // `$_` is a special symbol to ignore op argument matching.
+  if (!name.empty() && name != "_") {
     os.indent(indent) << formatv("{0} = tblgen_attr;\n", name);
   }
 

From f92e85d6f879d92ca0f1d9a426a93a7f2d67aae5 Mon Sep 17 00:00:00 2001
From: Dan Moldovan 
Date: Mon, 2 Dec 2019 07:55:47 -0800
Subject: [PATCH 142/279] Correctly look up __call__ in the class of an object,
 rather than the object itself. An example where they are not the same are
 callable metaclasses that can create callable classes.

PiperOrigin-RevId: 283340155
Change-Id: I5d5dac0beab3a14b1986fa78e8992dbe766c69d3
---
 tensorflow/python/autograph/impl/api.py      | 15 +++------
 tensorflow/python/autograph/impl/api_test.py | 33 ++++++++++++++++++++
 2 files changed, 37 insertions(+), 11 deletions(-)

diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py
index b8b0eeee63c..26e766598d7 100644
--- a/tensorflow/python/autograph/impl/api.py
+++ b/tensorflow/python/autograph/impl/api.py
@@ -508,19 +508,12 @@ def converted_call(f,
       else:
         effective_args = args
 
-    elif hasattr(f, '__call__') and hasattr(f, '__class__'):
-      # Callable objects
-      target_entity = f.__call__
+    elif hasattr(f, '__class__') and hasattr(f.__class__, '__call__'):
+      # Callable objects. Dunder methods have special lookup rules, see:
+      # https://docs.python.org/3/reference/datamodel.html#specialnames
+      target_entity = f.__class__.__call__
       effective_args = (f,) + args
 
-    elif tf_inspect.isclass(f):
-      # Constructors
-      # Note: Until we support class constructurs, and enable whole-class
-      # conversion with an experimental flag, this branch is dead code.
-      # TODO(mdan): Consider removing unless there is a compelling use case.
-      target_entity = f
-      effective_args = args
-
     else:
       target_entity = f
       raise NotImplementedError('unknown callable type "%s"' % type(f))
diff --git a/tensorflow/python/autograph/impl/api_test.py b/tensorflow/python/autograph/impl/api_test.py
index a3d9a870def..2eac1fefd54 100644
--- a/tensorflow/python/autograph/impl/api_test.py
+++ b/tensorflow/python/autograph/impl/api_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import abc
 import collections
 import contextlib
 import functools
@@ -416,12 +417,21 @@ class ApiTest(test.TestCase):
 
   def test_converted_call_callable_metaclass(self):
 
+    test_self = self
+
     class TestMetaclass(type):
 
       def __call__(cls):
         self.assertTrue(converter_testing.is_inside_generated_code())
         inst = object.__new__(cls)
         inst.__init__()
+
+        def instance_call(unused_self):
+          test_self.fail(
+              'The class-bound __call__ should be called, not the instance'
+              ' bound one.')
+
+        inst.__call__ = instance_call
         return inst
 
     tmc = TestMetaclass('TestClass', (), {})
@@ -431,6 +441,29 @@ class ApiTest(test.TestCase):
         functools.partial(tmc), (), None, options=DEFAULT_RECURSIVE)
     self.assertIsInstance(tc, tmc)
 
+  def test_converted_call_callable_abc(self):
+
+    test_self = self
+
+    @six.add_metaclass(abc.ABCMeta)
+    class TestBase(object):
+
+      @abc.abstractmethod
+      def __call__(self):
+        test_self.fail('This should not be called')
+
+    class TestSubclass(TestBase):
+
+      def __init__(self):
+        test_self.assertFalse(converter_testing.is_inside_generated_code())
+
+      def __call__(self, expected):
+        test_self.assertTrue(expected)
+        test_self.assertTrue(converter_testing.is_inside_generated_code())
+
+    tc = api.converted_call(TestSubclass, (), None, options=DEFAULT_RECURSIVE)
+    api.converted_call(tc, (True,), None, options=DEFAULT_RECURSIVE)
+
   @test_util.run_deprecated_v1
   def test_converted_call_constructor(self):
 

From 5247c0d46c9a0613b7573880bbd7eabd6c693b87 Mon Sep 17 00:00:00 2001
From: Denis Khalikov 
Date: Mon, 2 Dec 2019 07:58:38 -0800
Subject: [PATCH 143/279] Add missing `>` to the description of std.view.

Closes #266

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/266 from denis0x0D:sandbox/miss_char a5f662e1bf103b5009da67d045ee2fcebf822ab0
PiperOrigin-RevId: 283340486
Change-Id: I062fead03e37c106b5a4c83917e870d4dab04aea
---
 third_party/mlir/include/mlir/Dialect/StandardOps/Ops.td | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/third_party/mlir/include/mlir/Dialect/StandardOps/Ops.td b/third_party/mlir/include/mlir/Dialect/StandardOps/Ops.td
index 70cf3bb7775..7617d3cb247 100644
--- a/third_party/mlir/include/mlir/Dialect/StandardOps/Ops.td
+++ b/third_party/mlir/include/mlir/Dialect/StandardOps/Ops.td
@@ -1206,7 +1206,7 @@ def ViewOp : Std_Op<"view", [NoSideEffect]> {
 
     // ViewOp with dynamic offset and one dynamic size.
     %2 = view %0[%offset_1024][%size0]
-      : memref<2048xi8> to memref (d0 * 4 + d1 + s0)
+      : memref<2048xi8> to memref (d0 * 4 + d1 + s0)>
 
     // ViewOp creating 3D shape where two of the dim sizes are dynamic.
     // *) The dynamic offset specified in the ViewOp is applied to the
@@ -1219,7 +1219,7 @@ def ViewOp : Std_Op<"view", [NoSideEffect]> {
     //    shape and dynamic sizes.
     %3 = view %0[%offset_1024][%size0, %size1]
       : memref<2048xi8> to memref (d0 * s1 + d1 * 4 + d2 + s0)
+        (d0, d1, d2)[s0, s1] -> (d0 * s1 + d1 * 4 + d2 + s0)>
   }];
 
   let arguments = (ins MemRefRankOf<[I8], [1]>:$source,

From f27624a1e9e89dca3b18f0a201c08329ea06589e Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Mon, 2 Dec 2019 09:02:32 -0800
Subject: [PATCH 144/279] Don't use grpc_impl:: as it is not API

PiperOrigin-RevId: 283350814
Change-Id: I63ae81d112f37eaa175ee19460dfc5d3d521dd69
---
 tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.h b/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.h
index a3d146ff299..a64df225f2a 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.h
+++ b/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.h
@@ -24,7 +24,7 @@ namespace tpu_driver {
 
 xla::StatusOr> CreateGrpcTpuDriver(
     const TpuDriverConfig& config,
-    std::shared_ptr credentials);
+    std::shared_ptr credentials);
 
 }  // namespace tpu_driver
 

From 144cc5e670583379ee5689db4da35d1f253eb529 Mon Sep 17 00:00:00 2001
From: brett koonce 
Date: Mon, 2 Dec 2019 09:12:48 -0800
Subject: [PATCH 145/279] docs: minor spelling tweaks

Closes #262

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/262 from brettkoonce:docs-sp 6833fc8aa41edd02d8bc7c3cbb84211cb8b0334c
PiperOrigin-RevId: 283352765
Change-Id: Ic1fdbb7ea9c5a4c0a65a26be071eebd47d74e77a
---
 third_party/mlir/g3doc/DeclarativeRewrites.md | 4 ++--
 third_party/mlir/g3doc/Tutorials/Toy/Ch-2.md  | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/third_party/mlir/g3doc/DeclarativeRewrites.md b/third_party/mlir/g3doc/DeclarativeRewrites.md
index 5117ebd5c84..c7276daccd8 100644
--- a/third_party/mlir/g3doc/DeclarativeRewrites.md
+++ b/third_party/mlir/g3doc/DeclarativeRewrites.md
@@ -486,7 +486,7 @@ on **naming convention**: a `__N` suffix is added to a symbol to indicate the
 
 #### `__N` suffix
 
-The `__N` sufix is specifying the `N`-th result as a whole (which can be
+The `__N` suffix is specifying the `N`-th result as a whole (which can be
 [variadic](#supporting-variadic-ops)). For example, we can bind a symbol to some
 multi-result op and reference a specific result later:
 
@@ -681,7 +681,7 @@ mlir-tblgen --gen-rewriters -I /path/to/mlir/include /path/to/input/td/file
 
 ### Compilation error: no matching member function for call to 'build'
 
-This is because DRR is failing to call a `build()` mehtod with result type
+This is because DRR is failing to call a `build()` method with result type
 deduction ability. See [building operations](#building-operations) for more
 details.
 
diff --git a/third_party/mlir/g3doc/Tutorials/Toy/Ch-2.md b/third_party/mlir/g3doc/Tutorials/Toy/Ch-2.md
index 3e7c680bcb8..c23597244dd 100755
--- a/third_party/mlir/g3doc/Tutorials/Toy/Ch-2.md
+++ b/third_party/mlir/g3doc/Tutorials/Toy/Ch-2.md
@@ -395,7 +395,7 @@ documents.
 ```tablegen
 def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
   // Provide a summary and description for this operation. This can be used to
-  // auto-generate documenatation of the operations within our dialect.
+  // auto-generate documentation of the operations within our dialect.
   let summary = "constant operation";
   let description = [{
     Constant operation turns a literal into an SSA value. The data is attached
@@ -473,7 +473,7 @@ the implementation inline.
 ```tablegen
 def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
   // Provide a summary and description for this operation. This can be used to
-  // auto-generate documenatation of the operations within our dialect.
+  // auto-generate documentation of the operations within our dialect.
   let summary = "constant operation";
   let description = [{
     Constant operation turns a literal into an SSA value. The data is attached

From 2aecadb0d1b6c9fe2c2f4005b28b493d57cd3f64 Mon Sep 17 00:00:00 2001
From: Mehdi Amini 
Date: Mon, 2 Dec 2019 09:17:51 -0800
Subject: [PATCH 146/279] Generate dialect documentations in the doc folder for
 every dialect

Also add a mlir-doc build target to general all the docs

PiperOrigin-RevId: 283353529
Change-Id: I6e6805697f142d62f6f36d54b446adf50086e062
---
 third_party/mlir/CMakeLists.txt               | 21 +++++++++++++++++++
 .../mlir/Dialect/AffineOps/CMakeLists.txt     |  5 +----
 .../mlir/Dialect/FxpMathOps/CMakeLists.txt    |  5 +----
 .../include/mlir/Dialect/GPU/CMakeLists.txt   |  5 +----
 .../mlir/Dialect/LLVMIR/CMakeLists.txt        | 12 ++++-------
 .../mlir/Dialect/Linalg/IR/CMakeLists.txt     |  5 +----
 .../mlir/Dialect/LoopOps/CMakeLists.txt       |  5 +----
 .../mlir/Dialect/QuantOps/CMakeLists.txt      |  5 +----
 .../include/mlir/Dialect/SPIRV/CMakeLists.txt |  5 +----
 .../mlir/Dialect/VectorOps/CMakeLists.txt     |  5 +----
 10 files changed, 33 insertions(+), 40 deletions(-)

diff --git a/third_party/mlir/CMakeLists.txt b/third_party/mlir/CMakeLists.txt
index c8ffa759376..d6767fa75a8 100644
--- a/third_party/mlir/CMakeLists.txt
+++ b/third_party/mlir/CMakeLists.txt
@@ -12,6 +12,27 @@ function(mlir_tablegen ofn)
       PARENT_SCOPE)
 endfunction()
 
+function(add_mlir_dialect dialect)
+  set(LLVM_TARGET_DEFINITIONS ${dialect}.td)
+  mlir_tablegen(${dialect}.h.inc -gen-op-decls)
+  mlir_tablegen(${dialect}.cpp.inc -gen-op-defs)
+  add_public_tablegen_target(MLIR${dialect}IncGen)
+
+  # Generate Dialect Documentation
+  tablegen(MLIR ${dialect}.md -gen-op-doc "-I${MLIR_MAIN_SRC_DIR}" "-I${MLIR_INCLUDE_DIR}")
+  set(GEN_DOC_FILE ${MLIR_BINARY_DIR}/docs/Dialects/${dialect}.md)
+  add_custom_command(
+          OUTPUT ${GEN_DOC_FILE}
+          COMMAND ${CMAKE_COMMAND} -E copy
+                  ${CMAKE_CURRENT_BINARY_DIR}/${dialect}.md
+                  ${GEN_DOC_FILE}
+          DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/${dialect}.md)
+  add_custom_target(${dialect}DocGen DEPENDS ${GEN_DOC_FILE})
+  add_dependencies(mlir-doc ${dialect}DocGen)
+endfunction()
+
+add_custom_target(mlir-doc)
+
 # TODO: This is to handle the current static registration, but should be
 # factored out a bit.
 function(whole_archive_link target)
diff --git a/third_party/mlir/include/mlir/Dialect/AffineOps/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/AffineOps/CMakeLists.txt
index 6c5a58c957b..8f812b39593 100644
--- a/third_party/mlir/include/mlir/Dialect/AffineOps/CMakeLists.txt
+++ b/third_party/mlir/include/mlir/Dialect/AffineOps/CMakeLists.txt
@@ -1,4 +1 @@
-set(LLVM_TARGET_DEFINITIONS AffineOps.td)
-mlir_tablegen(AffineOps.h.inc -gen-op-decls)
-mlir_tablegen(AffineOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRAffineOpsIncGen)
+add_mlir_dialect(AffineOps)
diff --git a/third_party/mlir/include/mlir/Dialect/FxpMathOps/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/FxpMathOps/CMakeLists.txt
index eaf72d214f8..a8fb5e08ee5 100644
--- a/third_party/mlir/include/mlir/Dialect/FxpMathOps/CMakeLists.txt
+++ b/third_party/mlir/include/mlir/Dialect/FxpMathOps/CMakeLists.txt
@@ -1,4 +1 @@
-set(LLVM_TARGET_DEFINITIONS FxpMathOps.td)
-mlir_tablegen(FxpMathOps.h.inc -gen-op-decls)
-mlir_tablegen(FxpMathOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRFxpMathOpsIncGen)
+add_mlir_dialect(FxpMathOps)
diff --git a/third_party/mlir/include/mlir/Dialect/GPU/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/GPU/CMakeLists.txt
index 5ba59a1026c..bdb5dec79b9 100644
--- a/third_party/mlir/include/mlir/Dialect/GPU/CMakeLists.txt
+++ b/third_party/mlir/include/mlir/Dialect/GPU/CMakeLists.txt
@@ -1,4 +1 @@
-set(LLVM_TARGET_DEFINITIONS GPUOps.td)
-mlir_tablegen(GPUOps.h.inc -gen-op-decls)
-mlir_tablegen(GPUOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRGPUOpsIncGen)
+add_mlir_dialect(GPUOps)
diff --git a/third_party/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
index 3e5a0346ed6..4ecc71aef08 100644
--- a/third_party/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/third_party/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -4,14 +4,10 @@ mlir_tablegen(LLVMOps.cpp.inc -gen-op-defs)
 mlir_tablegen(LLVMOpsEnums.h.inc -gen-enum-decls)
 mlir_tablegen(LLVMOpsEnums.cpp.inc -gen-enum-defs)
 add_public_tablegen_target(MLIRLLVMOpsIncGen)
-set(LLVM_TARGET_DEFINITIONS NVVMOps.td)
-mlir_tablegen(NVVMOps.h.inc -gen-op-decls)
-mlir_tablegen(NVVMOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRNVVMOpsIncGen)
-set(LLVM_TARGET_DEFINITIONS ROCDLOps.td)
-mlir_tablegen(ROCDLOps.h.inc -gen-op-decls)
-mlir_tablegen(ROCDLOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRROCDLOpsIncGen)
+
+add_mlir_dialect(NVVMOps)
+add_mlir_dialect(ROCDLOps)
+
 set(LLVM_TARGET_DEFINITIONS LLVMOps.td)
 mlir_tablegen(LLVMConversions.inc -gen-llvmir-conversions)
 add_public_tablegen_target(MLIRLLVMConversionsIncGen)
diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
index b175e9ad044..2a883a138a5 100644
--- a/third_party/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
+++ b/third_party/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
@@ -1,7 +1,4 @@
-set(LLVM_TARGET_DEFINITIONS LinalgOps.td)
-mlir_tablegen(LinalgOps.h.inc -gen-op-decls)
-mlir_tablegen(LinalgOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRLinalgOpsIncGen)
+add_mlir_dialect(LinalgOps)
 set(LLVM_TARGET_DEFINITIONS LinalgLibraryOps.td)
 mlir_tablegen(LinalgLibraryOps.h.inc -gen-op-decls)
 mlir_tablegen(LinalgLibraryOps.cpp.inc -gen-op-defs)
diff --git a/third_party/mlir/include/mlir/Dialect/LoopOps/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/LoopOps/CMakeLists.txt
index 2d699580c04..9f5863f2be9 100644
--- a/third_party/mlir/include/mlir/Dialect/LoopOps/CMakeLists.txt
+++ b/third_party/mlir/include/mlir/Dialect/LoopOps/CMakeLists.txt
@@ -1,4 +1 @@
-set(LLVM_TARGET_DEFINITIONS LoopOps.td)
-mlir_tablegen(LoopOps.h.inc -gen-op-decls)
-mlir_tablegen(LoopOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRLoopOpsIncGen)
+add_mlir_dialect(LoopOps)
diff --git a/third_party/mlir/include/mlir/Dialect/QuantOps/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/QuantOps/CMakeLists.txt
index 3e3b9462b88..f95532ecf6e 100644
--- a/third_party/mlir/include/mlir/Dialect/QuantOps/CMakeLists.txt
+++ b/third_party/mlir/include/mlir/Dialect/QuantOps/CMakeLists.txt
@@ -1,4 +1 @@
-set(LLVM_TARGET_DEFINITIONS QuantOps.td)
-mlir_tablegen(QuantOps.h.inc -gen-op-decls)
-mlir_tablegen(QuantOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRQuantOpsIncGen)
+add_mlir_dialect(QuantOps)
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt
index c18d6534261..b6759a9111b 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt
@@ -3,10 +3,7 @@ mlir_tablegen(SPIRVLowering.h.inc -gen-struct-attr-decls)
 mlir_tablegen(SPIRVLowering.cpp.inc -gen-struct-attr-defs)
 add_public_tablegen_target(MLIRSPIRVLoweringStructGen)
 
-set(LLVM_TARGET_DEFINITIONS SPIRVOps.td)
-mlir_tablegen(SPIRVOps.h.inc -gen-op-decls)
-mlir_tablegen(SPIRVOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRSPIRVOpsIncGen)
+add_mlir_dialect(SPIRVOps)
 
 set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
 mlir_tablegen(SPIRVEnums.h.inc -gen-enum-decls)
diff --git a/third_party/mlir/include/mlir/Dialect/VectorOps/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/VectorOps/CMakeLists.txt
index 3849dd7ffdf..c165c5e676d 100644
--- a/third_party/mlir/include/mlir/Dialect/VectorOps/CMakeLists.txt
+++ b/third_party/mlir/include/mlir/Dialect/VectorOps/CMakeLists.txt
@@ -1,7 +1,4 @@
-set(LLVM_TARGET_DEFINITIONS VectorOps.td)
-mlir_tablegen(VectorOps.h.inc -gen-op-decls)
-mlir_tablegen(VectorOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRVectorOpsIncGen)
+add_mlir_dialect(VectorOps)
 
 set(LLVM_TARGET_DEFINITIONS VectorTransformPatterns.td)
 mlir_tablegen(VectorTransformPatterns.h.inc -gen-rewriters)

From b9f6bc52a2b155e61bc102d8fde74cf0c79027f7 Mon Sep 17 00:00:00 2001
From: Lei Zhang 
Date: Mon, 2 Dec 2019 09:33:24 -0800
Subject: [PATCH 147/279] [ODS] Generate builders taking unwrapped value and
 defaults for attributes

Existing builders generated by ODS require attributes to be passed
in as mlir::Attribute or its subclasses. This is okay foraggregate-
parameter builders, which is primarily to be used by programmatic
C++ code generation; it is inconvenient for separate-parameter
builders meant to be called in manually written C++ code because
it requires developers to wrap raw values into mlir::Attribute by
themselves.

This CL extends to generate additional builder methods that
take raw values for attributes and handles the wrapping in the
builder implementation. Additionally, if an attribute appears
late in the arguments list and has a default value, the default
value is supplied in the declaration if possible.

PiperOrigin-RevId: 283355919
Change-Id: Idd276a302f780c4313f7d53fb751167f4665150b
---
 third_party/mlir/g3doc/OpDefinitions.md       |  79 ++++++-
 .../mlir/include/mlir/TableGen/Attribute.h    |   8 +-
 .../mlir/include/mlir/TableGen/Operator.h     |   1 +
 third_party/mlir/lib/TableGen/Attribute.cpp   |   4 +-
 .../tools/mlir-tblgen/OpDefinitionsGen.cpp    | 203 ++++++++++++++----
 .../mlir/tools/mlir-tblgen/RewriterGen.cpp    |   4 +-
 6 files changed, 243 insertions(+), 56 deletions(-)

diff --git a/third_party/mlir/g3doc/OpDefinitions.md b/third_party/mlir/g3doc/OpDefinitions.md
index ea794964033..25865593800 100644
--- a/third_party/mlir/g3doc/OpDefinitions.md
+++ b/third_party/mlir/g3doc/OpDefinitions.md
@@ -382,27 +382,86 @@ def OpWithInferTypeInterfaceOp : Op<...
     [DeclareOpInterfaceMethods]> { ... }
 ```
 
-### Custom builder methods
+### Builder methods
 
-For each operation, there are two builders automatically generated based on the
-arguments and returns types:
+For each operation, there are a few builders automatically generated based on
+the arguments and returns types. For example, given the following op definition:
+
+```tablegen
+def MyOp : ... {
+  let arguments = (ins
+    I32:$i32_operand,
+    F32:$f32_operand,
+    ...,
+
+    I32Attr:$i32_attr,
+    F32Attr:$f32_attr,
+    ...
+  );
+
+  let results = (outs
+    I32:$i32_result,
+    F32:$f32_result,
+    ...
+  );
+}
+```
+
+The following builders are generated:
 
 ```c++
-static void build(Builder *, OperationState &tblgen_state,
-                  Type , Type , ...,
-                  Value , Value , ...,
-                  Attribute , Attribute , ...);
-
-static void build(Builder *, OperationState &tblgen_state,
+// All result-types/operands/attributes have one aggregate parameter.
+static void build(Builder *tblgen_builder, OperationState &tblgen_state,
                   ArrayRef resultTypes,
                   ArrayRef operands,
                   ArrayRef attributes);
+
+// Each result-type/operand/attribute has a separate parameter. The parameters
+// for attributes are of mlir::Attribute types.
+static void build(Builder *tblgen_builder, OperationState &tblgen_state,
+                  Type i32_result, Type f32_result, ...,
+                  Value *i32_operand, Value *f32_operand, ...,
+                  IntegerAttr i32_attr, FloatAttr f32_attr, ...);
+
+// Each result-type/operand/attribute has a separate parameter. The parameters
+// for attributes are raw values unwrapped with mlir::Attribute instances.
+// (Note that this builder will not always be generated. See the following
+// explanation for more details.)
+static void build(Builder *tblgen_builder, OperationState &tblgen_state,
+                  Type i32_result, Type f32_result, ...,
+                  Value *i32_operand, Value *f32_operand, ...,
+                  APInt i32_attr, StringRef f32_attr, ...);
+
+// (And potentially others depending on the specific op.)
 ```
 
-The above cases make sure basic uniformity so that we can create ops using the
+The first form provides basic uniformity so that we can create ops using the
 same form regardless of the exact op. This is particularly useful for
 implementing declarative pattern rewrites.
 
+The second and third forms are good for use in manually written code given that
+they provide better guarantee via signatures.
+
+The third form will be generated if any of the op's attribute has different
+`Attr.returnType` from `Attr.storageType` and we know how to build an attribute
+from an unwrapped value (i.e., `Attr.constBuilderCall` is defined.)
+Additionally, for the third form, if an attribute appearing later in the
+`arguments` list has a default value, the default value will be supplied in the
+declaration. This works for `BoolAttr`, `StrAttr`, `EnumAttr` for now and the
+list can grow in the future. So if possible, default valued attribute should be
+placed at the end of the `arguments` list to leverage this feature. (This
+behavior is essentially due to C++ function parameter default value placement
+restrictions.) Otherwise, the builder of the third form will still be generated
+but default values for the attributes not at the end of the `arguments` list
+will not be supplied in the builder's signature.
+
+And there may potentially exist other builders depending on the specific op;
+please refer to the
+[generated C++ file](#run-mlir-tblgen-to-see-the-generated-content) for the
+complete list.
+
+#### Custom builder methods
+
 However, if the above cases cannot satisfy all needs, you can define additional
 convenience build methods with `OpBuilder`.
 
diff --git a/third_party/mlir/include/mlir/TableGen/Attribute.h b/third_party/mlir/include/mlir/TableGen/Attribute.h
index 60f95156bb5..242376e24ff 100644
--- a/third_party/mlir/include/mlir/TableGen/Attribute.h
+++ b/third_party/mlir/include/mlir/TableGen/Attribute.h
@@ -81,10 +81,10 @@ public:
   // built upon.
   Attribute getBaseAttr() const;
 
-  // Returns whether this attribute has a default value's initializer.
-  bool hasDefaultValueInitializer() const;
-  // Returns the default value's initializer for this attribute.
-  StringRef getDefaultValueInitializer() const;
+  // Returns whether this attribute has a default value.
+  bool hasDefaultValue() const;
+  // Returns the default value for this attribute.
+  StringRef getDefaultValue() const;
 
   // Returns whether this attribute is optional.
   bool isOptional() const;
diff --git a/third_party/mlir/include/mlir/TableGen/Operator.h b/third_party/mlir/include/mlir/TableGen/Operator.h
index 7b636ddb79e..89fd4ed8d2e 100644
--- a/third_party/mlir/include/mlir/TableGen/Operator.h
+++ b/third_party/mlir/include/mlir/TableGen/Operator.h
@@ -103,6 +103,7 @@ public:
   llvm::iterator_range getAttributes() const;
 
   int getNumAttributes() const { return attributes.size(); }
+  int getNumNativeAttributes() const { return numNativeAttributes; }
 
   // Op attribute accessors.
   NamedAttribute &getAttribute(int index) { return attributes[index]; }
diff --git a/third_party/mlir/lib/TableGen/Attribute.cpp b/third_party/mlir/lib/TableGen/Attribute.cpp
index c2b673a7c93..ec946a855fc 100644
--- a/third_party/mlir/lib/TableGen/Attribute.cpp
+++ b/third_party/mlir/lib/TableGen/Attribute.cpp
@@ -107,12 +107,12 @@ tblgen::Attribute tblgen::Attribute::getBaseAttr() const {
   return *this;
 }
 
-bool tblgen::Attribute::hasDefaultValueInitializer() const {
+bool tblgen::Attribute::hasDefaultValue() const {
   const auto *init = def->getValueInit("defaultValue");
   return !getValueAsString(init).empty();
 }
 
-StringRef tblgen::Attribute::getDefaultValueInitializer() const {
+StringRef tblgen::Attribute::getDefaultValue() const {
   const auto *init = def->getValueInit("defaultValue");
   return getValueAsString(init);
 }
diff --git a/third_party/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/third_party/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 864f7734f8a..dcecd1c65be 100644
--- a/third_party/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/third_party/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -32,6 +32,8 @@
 #include "llvm/TableGen/Record.h"
 #include "llvm/TableGen/TableGenBackend.h"
 
+#define DEBUG_TYPE "mlir-tblgen-opdefgen"
+
 using namespace llvm;
 using namespace mlir;
 using namespace mlir::tblgen;
@@ -113,6 +115,14 @@ static std::string getArgumentName(const Operator &op, int index) {
     return formatv("{0}_{1}", generatedArgName, index);
 }
 
+// Returns true if we can use unwrapped value for the given `attr` in builders.
+static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) {
+  return attr.getReturnType() != attr.getStorageType() &&
+         // We need to wrap the raw value into an attribute in the builder impl
+         // so we need to make sure that the attribute specifies how to do that.
+         !attr.getConstBuilderTemplate().empty();
+}
+
 namespace {
 // Simple RAII helper for defining ifdef-undef-endif scopes.
 class IfDefScope {
@@ -506,46 +516,66 @@ private:
   void genBuilder();
 
   // Generates the build() method that takes each result-type/operand/attribute
-  // as a stand-alone parameter. This build() method also requires specifying
-  // result types for all results.
-  void genSeparateParamBuilder();
+  // as a stand-alone parameter. Attributes will take wrapped mlir::Attribute
+  // values. The generated build() method also requires specifying result types
+  // for all results.
+  void genSeparateParamWrappedAttrBuilder();
+
+  // Generates the build() method that takes each result-type/operand/attribute
+  // as a stand-alone parameter. Attributes will take raw values without
+  // mlir::Attribute wrapper. The generated build() method also requires
+  // specifying result types for all results.
+  void genSeparateParamUnwrappedAttrBuilder();
 
   // Generates the build() method that takes a single parameter for all the
   // result types and a separate parameter for each operand/attribute.
   void genCollectiveTypeParamBuilder();
 
   // Generates the build() method that takes each operand/attribute as a
-  // stand-alone parameter. This build() method uses first operand's type
-  // as all results' types.
+  // stand-alone parameter. The generated build() method uses first operand's
+  // type as all results' types.
   void genUseOperandAsResultTypeSeparateParamBuilder();
 
   // Generates the build() method that takes all operands/attributes
-  // collectively as one parameter. This build() method uses first operand's
-  // type as all results' types.
+  // collectively as one parameter. The generated build() method uses first
+  // operand's type as all results' types.
   void genUseOperandAsResultTypeCollectiveParamBuilder();
 
   // Generates the build() method that takes each operand/attribute as a
-  // stand-alone parameter. This build() method uses first attribute's type
-  // as all result's types.
+  // stand-alone parameter. The generated build() method uses first attribute's
+  // type as all result's types.
   void genUseAttrAsResultTypeBuilder();
 
   // Generates the build() method that takes all result types collectively as
   // one parameter. Similarly for operands and attributes.
   void genCollectiveParamBuilder();
 
-  enum class TypeParamKind { None, Separate, Collective };
+  // The kind of parameter to generate for result types in builders.
+  enum class TypeParamKind {
+    None,       // No result type in parameter list.
+    Separate,   // A separate parameter for each result type.
+    Collective, // An ArrayRef for all result types.
+  };
+
+  // The kind of parameter to generate for attributes in builders.
+  enum class AttrParamKind {
+    WrappedAttr,    // A wrapped MLIR Attribute instance.
+    UnwrappedValue, // A raw value without MLIR Attribute wrapper.
+  };
 
   // Builds the parameter list for build() method of this op. This method writes
-  // to `paramList` the comma-separated parameter list. If `includeResultTypes`
-  // is true then `paramList` will also contain the parameters for all results
-  // and `resultTypeNames` will be populated with the parameter name for each
-  // result type.
+  // to `paramList` the comma-separated parameter list and updates
+  // `resultTypeNames` with the names for parameters for specifying result
+  // types. The given `typeParamKind` and `attrParamKind` controls how result
+  // types and attributes are placed in the parameter list.
   void buildParamList(std::string ¶mList,
                       SmallVectorImpl &resultTypeNames,
-                      TypeParamKind kind);
+                      TypeParamKind typeParamKind,
+                      AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);
 
   // Adds op arguments and regions into operation state for build() methods.
-  void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body);
+  void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
+                                              bool isRawValueAttr = false);
 
   // Generates canonicalizer declaration for the operation.
   void genCanonicalizerDecls();
@@ -650,18 +680,18 @@ void OpEmitter::genAttrGetters() {
 
     // Return the queried attribute with the correct return type.
     auto attrVal =
-        (attr.hasDefaultValueInitializer() || attr.isOptional())
+        (attr.hasDefaultValue() || attr.isOptional())
             ? formatv("this->getAttr(\"{0}\").dyn_cast_or_null<{1}>()", name,
                       attr.getStorageType())
             : formatv("this->getAttr(\"{0}\").cast<{1}>()", name,
                       attr.getStorageType());
     body << "  auto attr = " << attrVal << ";\n";
-    if (attr.hasDefaultValueInitializer()) {
+    if (attr.hasDefaultValue()) {
       // Returns the default value if not set.
       // TODO: this is inefficient, we are recreating the attribute for every
       // call. This should be set instead.
-      std::string defaultValue = tgfmt(attr.getConstBuilderTemplate(), &fctx,
-                                       attr.getDefaultValueInitializer());
+      std::string defaultValue =
+          tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue());
       body << "    if (!attr)\n      return "
            << tgfmt(attr.getConvertFromStorageCall(),
                     &fctx.withSelf(defaultValue))
@@ -847,7 +877,7 @@ void OpEmitter::genNamedRegionGetters() {
   }
 }
 
-void OpEmitter::genSeparateParamBuilder() {
+void OpEmitter::genSeparateParamWrappedAttrBuilder() {
   std::string paramList;
   llvm::SmallVector resultNames;
   buildParamList(paramList, resultNames, TypeParamKind::Separate);
@@ -862,6 +892,42 @@ void OpEmitter::genSeparateParamBuilder() {
   }
 }
 
+void OpEmitter::genSeparateParamUnwrappedAttrBuilder() {
+  // If this op does not have native attributes at all, return directly to avoid
+  // redefining builders.
+  if (op.getNumNativeAttributes() == 0)
+    return;
+
+  bool canGenerate = false;
+  // We are generating builders that take raw values for attributes. We need to
+  // make sure the native attributes have a meaningful "unwrapped" value type
+  // different from the wrapped mlir::Attribute type to avoid redefining
+  // builders. This checks for the op has at least one such native attribute.
+  for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) {
+    NamedAttribute &namedAttr = op.getAttribute(i);
+    if (canUseUnwrappedRawValue(namedAttr.attr)) {
+      canGenerate = true;
+      break;
+    }
+  }
+  if (!canGenerate)
+    return;
+
+  std::string paramList;
+  llvm::SmallVector resultNames;
+  buildParamList(paramList, resultNames, TypeParamKind::Separate,
+                 AttrParamKind::UnwrappedValue);
+
+  auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
+  genCodeForAddingArgAndRegionForBuilder(m.body(), /*isRawValueAttr=*/true);
+
+  // Push all result types to the operation state.
+  for (int i = 0, e = op.getNumResults(); i < e; ++i) {
+    m.body() << "  " << builderOpState << ".addTypes(" << resultNames[i]
+             << ");\n";
+  }
+}
+
 void OpEmitter::genCollectiveTypeParamBuilder() {
   auto numResults = op.getNumResults();
 
@@ -1006,7 +1072,8 @@ void OpEmitter::genBuilder() {
   // We generate three builders here:
   // 1. one having a stand-alone parameter for each result type / operand /
   //    attribute, and
-  genSeparateParamBuilder();
+  genSeparateParamWrappedAttrBuilder();
+  genSeparateParamUnwrappedAttrBuilder();
   // 2. one having a stand-alone parameter for each operand / attribute and
   //    an aggregated parameter for all result types, and
   genCollectiveTypeParamBuilder();
@@ -1069,15 +1136,16 @@ void OpEmitter::genCollectiveParamBuilder() {
 
 void OpEmitter::buildParamList(std::string ¶mList,
                                SmallVectorImpl &resultTypeNames,
-                               TypeParamKind kind) {
+                               TypeParamKind typeParamKind,
+                               AttrParamKind attrParamKind) {
   resultTypeNames.clear();
   auto numResults = op.getNumResults();
   resultTypeNames.reserve(numResults);
 
-  paramList = "Builder *, OperationState &";
+  paramList = "Builder *tblgen_builder, OperationState &";
   paramList.append(builderOpState);
 
-  switch (kind) {
+  switch (typeParamKind) {
   case TypeParamKind::None:
     break;
   case TypeParamKind::Separate: {
@@ -1100,10 +1168,36 @@ void OpEmitter::buildParamList(std::string ¶mList,
   } break;
   }
 
+  // Add parameters for all arguments (operands and attributes).
+
   int numOperands = 0;
   int numAttrs = 0;
 
-  // Add parameters for all arguments (operands and attributes).
+  int defaultValuedAttrStartIndex = op.getNumArgs();
+  if (attrParamKind == AttrParamKind::UnwrappedValue) {
+    // Calculate the start index from which we can attach default values in the
+    // builder declaration.
+    for (int i = op.getNumArgs() - 1; i >= 0; --i) {
+      auto *namedAttr = op.getArg(i).dyn_cast();
+      if (!namedAttr || !namedAttr->attr.hasDefaultValue())
+        break;
+
+      if (!canUseUnwrappedRawValue(namedAttr->attr))
+        break;
+
+      // Creating an APInt requires us to provide bitwidth, value, and
+      // signedness, which is complicated compared to others. Similarly
+      // for APFloat.
+      // TODO(b/144412160) Adjust the 'returnType' field of such attributes
+      // to support them.
+      StringRef retType = namedAttr->attr.getReturnType();
+      if (retType == "APInt" || retType == "APFloat")
+        break;
+
+      defaultValuedAttrStartIndex = i;
+    }
+  }
+
   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
     auto argument = op.getArg(i);
     if (argument.is()) {
@@ -1113,24 +1207,46 @@ void OpEmitter::buildParamList(std::string ¶mList,
       paramList.append(getArgumentName(op, numOperands));
       ++numOperands;
     } else {
-      // TODO(antiagainst): Support default initializer for attributes
       const auto &namedAttr = op.getAttribute(numAttrs);
       const auto &attr = namedAttr.attr;
       paramList.append(", ");
+
       if (attr.isOptional())
         paramList.append("/*optional*/");
-      paramList.append(attr.getStorageType());
+
+      switch (attrParamKind) {
+      case AttrParamKind::WrappedAttr:
+        paramList.append(attr.getStorageType());
+        break;
+      case AttrParamKind::UnwrappedValue:
+        if (canUseUnwrappedRawValue(attr)) {
+          paramList.append(attr.getReturnType());
+        } else {
+          paramList.append(attr.getStorageType());
+        }
+        break;
+      }
       paramList.append(" ");
       paramList.append(namedAttr.name);
+
+      // Attach default value if requested and possible.
+      if (attrParamKind == AttrParamKind::UnwrappedValue &&
+          i >= defaultValuedAttrStartIndex) {
+        bool isString = attr.getReturnType() == "StringRef";
+        paramList.append(" = ");
+        if (isString)
+          paramList.append("\"");
+        paramList.append(attr.getDefaultValue());
+        if (isString)
+          paramList.append("\"");
+      }
       ++numAttrs;
     }
   }
-
-  if (numOperands + numAttrs != op.getNumArgs())
-    PrintFatalError("op arguments must be either operands or attributes");
 }
 
-void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body) {
+void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
+                                                       bool isRawValueAttr) {
   // Push all operands to the result
   for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
     body << "  " << builderOpState << ".addOperands(" << getArgumentName(op, i)
@@ -1139,13 +1255,25 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body) {
 
   // Push all attributes to the result
   for (const auto &namedAttr : op.getAttributes()) {
-    if (!namedAttr.attr.isDerivedAttr()) {
-      bool emitNotNullCheck = namedAttr.attr.isOptional();
+    auto &attr = namedAttr.attr;
+    if (!attr.isDerivedAttr()) {
+      bool emitNotNullCheck = attr.isOptional();
       if (emitNotNullCheck) {
         body << formatv("  if ({0}) ", namedAttr.name) << "{\n";
       }
-      body << formatv("  {0}.addAttribute(\"{1}\", {1});\n", builderOpState,
-                      namedAttr.name);
+      if (isRawValueAttr and canUseUnwrappedRawValue(attr)) {
+        // If this is a raw value, then we need to wrap it in an Attribute
+        // instance.
+        FmtContext fctx;
+        fctx.withBuilder("(*tblgen_builder)");
+        std::string value =
+            tgfmt(attr.getConstBuilderTemplate(), &fctx, namedAttr.name);
+        body << formatv("  {0}.addAttribute(\"{1}\", {2});\n", builderOpState,
+                        namedAttr.name, value);
+      } else {
+        body << formatv("  {0}.addAttribute(\"{1}\", {1});\n", builderOpState,
+                        namedAttr.name);
+      }
       if (emitNotNullCheck) {
         body << "  }\n";
       }
@@ -1282,8 +1410,7 @@ void OpEmitter::genVerifier() {
     body << formatv("  auto {0} = this->getAttr(\"{1}\");\n", varName,
                     attrName);
 
-    bool allowMissingAttr =
-        attr.hasDefaultValueInitializer() || attr.isOptional();
+    bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional();
     if (allowMissingAttr) {
       // If the attribute has a default value, then only verify the predicate if
       // set. This does effectively assume that the default value is valid.
diff --git a/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp b/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp
index a2cace7fb60..d321b204f4e 100644
--- a/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -342,10 +342,10 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
       attr.getStorageType(), namedAttr->name);
 
   // TODO(antiagainst): This should use getter method to avoid duplication.
-  if (attr.hasDefaultValueInitializer()) {
+  if (attr.hasDefaultValue()) {
     os.indent(indent) << "if (!tblgen_attr) tblgen_attr = "
                       << tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
-                               attr.getDefaultValueInitializer())
+                               attr.getDefaultValue())
                       << ";\n";
   } else if (attr.isOptional()) {
     // For a missing attribute that is optional according to definition, we

From ab0c78ef618c73b0630349bd3090a81a7e7f8320 Mon Sep 17 00:00:00 2001
From: Aart Bik 
Date: Mon, 2 Dec 2019 09:56:58 -0800
Subject: [PATCH 148/279] [VectorOps] Add legality rules to broadcast

PiperOrigin-RevId: 283360101
Change-Id: Iac9165cbadd0116af85e9d353f767a1efc80537a
---
 .../mlir/Dialect/VectorOps/VectorOps.td       | 19 ++++++++++++++++++-
 .../mlir/lib/Dialect/VectorOps/VectorOps.cpp  | 13 ++++++++++---
 2 files changed, 28 insertions(+), 4 deletions(-)

diff --git a/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
index c78334dd54a..c75f9fe0231 100644
--- a/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -171,7 +171,24 @@ def Vector_BroadcastOp :
   let summary = "broadcast operation";
   let description = [{
     Broadcasts the scalar or k-D vector value in the source operand
-    to a n-D result vector such that the broadcast makes sense.
+    to a n-D result vector such that the broadcast makes sense, i.e.,
+    the source operand is duplicated to match the given rank and sizes
+    in the result vector. The legality rules are:
+    * the source operand must have the same element type as the result type
+    * a k-D vector  can be broadcast to
+      a n-D vector  if
+       * k <= n, and
+       * the sizes in the trailing dimensions n-k < i <= n with j=i+k-n
+          match exactly as s_j = t_i or s_j = 1:
+       ```
+           t_1 x   ..  t_n-k x t_n-k+1 x .. x t_i x .. x t_n
+                               s_1     x .. x s_j x .. x s_k
+                        
+       ```
+    The source operand is duplicated over all the missing leading dimensions
+    and streched over the trailing dimensions where the source has a non-equal
+    dimension of 1. These rules imply that any scalar broadcast (k=0) to any
+    shaped vector with the same element type is always legal.
 
     Examples:
     ```
diff --git a/third_party/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/third_party/mlir/lib/Dialect/VectorOps/VectorOps.cpp
index fe320b91439..6086531e3c7 100644
--- a/third_party/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/third_party/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -386,10 +386,17 @@ static LogicalResult verify(BroadcastOp op) {
   if (srcVectorType) {
     const int64_t srcRank = srcVectorType.getRank();
     const int64_t dstRank = dstVectorType.getRank();
-    // TODO(ajcbik): implement proper rank testing for broadcast;
-    // this is just a temporary placeholder check.
-    if (srcRank > dstRank) {
+    if (srcRank > dstRank)
       return op.emitOpError("source rank higher than destination rank");
+    // Source has an exact match or singleton value for all trailing dimensions
+    // (all leading dimensions are simply duplicated).
+    const int64_t lead = dstRank - srcRank;
+    for (int64_t i = 0; i < srcRank; i++) {
+      const int64_t srcDim = srcVectorType.getDimSize(i);
+      const int64_t dstDim = dstVectorType.getDimSize(lead + i);
+      if (srcDim != 1 && srcDim != dstDim)
+        return op.emitOpError("dimension mismatch (")
+               << srcDim << " vs. " << dstDim << ")";
     }
   }
   return success();

From f59a632323fec4bf8b3a0192fe2dacbb71d05b49 Mon Sep 17 00:00:00 2001
From: Prakalp Srivastava 
Date: Mon, 2 Dec 2019 10:04:32 -0800
Subject: [PATCH 149/279] Let tf-mlir-translate support -split-input-file

Having a -split-input-file mode is useful in tf-mlir-translate. It would allow us to put logically related tests in the same test file for better organization.

PiperOrigin-RevId: 283362105
Change-Id: Ia6eec42221f34677f49be9a4ff3da27e1671ca02
---
 .../compiler/mlir/tf_mlir_translate_main.cc   | 32 +++++++++++++++----
 1 file changed, 25 insertions(+), 7 deletions(-)

diff --git a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc
index f12dde278ae..9ab31265a33 100644
--- a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc
+++ b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc
@@ -24,6 +24,7 @@ limitations under the License.
 #include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
 #include "mlir/Support/FileUtilities.h"  // TF:local_config_mlir
 #include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
+#include "mlir/Support/ToolUtilities.h"  // TF:local_config_mlir
 #include "mlir/Support/TranslateClParser.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/init_mlir.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
@@ -40,6 +41,13 @@ static llvm::cl::opt output_filename(
     "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
     llvm::cl::init("-"));
 
+// NOLINTNEXTLINE
+static llvm::cl::opt splitInputFile(
+    "split-input-file",
+    llvm::cl::desc("Split the input file into pieces and process each chunk "
+                   "independently"),
+    llvm::cl::init(false));
+
 // NOLINTNEXTLINE
 static llvm::cl::opt import_saved_model(
     "savedmodel-to-mlir",
@@ -85,13 +93,12 @@ int main(int argc, char** argv) {
     return 1;
   }
 
-  mlir::MLIRContext context;
-
   if (import_saved_model) {
     std::unordered_set tags =
         absl::StrSplit(saved_model_tags, ',');
     std::vector exported_names =
         absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
+    mlir::MLIRContext context;
 
     auto module = tensorflow::SavedModelToMlirImport(
         input_filename, tags, absl::Span(exported_names),
@@ -107,12 +114,23 @@ int main(int argc, char** argv) {
       return 1;
     }
 
-    llvm::SourceMgr source_mgr;
-    source_mgr.AddNewSourceBuffer(std::move(input), llvm::SMLoc());
-    mlir::SourceMgrDiagnosticHandler diagnostic_handler(source_mgr, &context);
+    // Processes the memory buffer with a new MLIRContext.
+    auto processBuffer = [&](std::unique_ptr ownedBuffer,
+                             llvm::raw_ostream& os) {
+      llvm::SourceMgr sourceMgr;
+      sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc());
+      mlir::MLIRContext context;
+      mlir::SourceMgrDiagnosticHandler diagnostic_handler(sourceMgr, &context);
+      return (*requested_translation)(sourceMgr, os, &context);
+    };
 
-    if (failed((*requested_translation)(source_mgr, output->os(), &context)))
-      return 1;
+    if (splitInputFile) {
+      if (failed(mlir::splitAndProcessBuffer(std::move(input), processBuffer,
+                                             output->os())))
+        return 1;
+    } else {
+      if (failed(processBuffer(std::move(input), output->os()))) return 1;
+    }
   }
 
   output->keep();

From 3b3e88da1109b5d62d6b14ca373e43dff5d0f62e Mon Sep 17 00:00:00 2001
From: George Karpenkov 
Date: Mon, 2 Dec 2019 10:16:39 -0800
Subject: [PATCH 150/279] [XLA/GPU] [NFC] Remove unused function

PiperOrigin-RevId: 283364753
Change-Id: I21f2f19b592034384e2147dbcecefb1b13f6f839
---
 tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h b/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h
index ae6f4e39560..345abbd0935 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h
+++ b/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h
@@ -120,10 +120,6 @@ class KernelMappingScheme {
     return dims_in_blocks_;
   }
 
-  int64 GetNumberOfTilesInTotal() const {
-    return absl::c_accumulate(dims_in_tiles_, 1LL, std::multiplies());
-  }
-
   int64 GetNumberOfTilesInOneBlock() const { return block_size_z_; }
 
   int64 BlockSizeZ() const { return block_size_z_; }

From f20ae9c2b5c5913ac44927e2b85efc0b3e1d30a5 Mon Sep 17 00:00:00 2001
From: George Karpenkov 
Date: Mon, 2 Dec 2019 10:22:00 -0800
Subject: [PATCH 151/279] [XLA GPU] [NFC] Use the same function to check that
 two reductions are fusible

PiperOrigin-RevId: 283365847
Change-Id: Ifde925a68fce329b812fd34acaf0925e2c28b6c6
---
 .../compiler/xla/service/gpu/gpu_fusible.cc   |  4 +--
 .../xla/service/gpu/ir_emission_utils.cc      | 28 ++++++++++++++++
 .../xla/service/gpu/ir_emission_utils.h       |  5 +++
 .../xla/service/gpu/ir_emitter_unnested.cc    | 32 +++----------------
 4 files changed, 38 insertions(+), 31 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
index 599eef4e600..24738683a19 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
@@ -154,11 +154,9 @@ bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1,
   // operand shape) and the reduction dimensions need to match.
   auto* instr_1 = get_real_hero(&instr1);
   auto* instr_2 = get_real_hero(&instr2);
-  // TODO(tjoerg): Relax the shape constraint. The datatype does not matter.
   if (IsReductionFromOrToContiguousDimensions(*instr_1) &&
       IsReductionFromOrToContiguousDimensions(*instr_2) &&
-      (!ShapeUtil::Equal(instr_1->shape(), instr_2->shape()) ||
-       instr_1->dimensions() != instr_2->dimensions())) {
+      !AreFusedReductionOutputsConsistent({instr_1, instr_2}, instr_1)) {
     return false;
   }
   // The elementwise output shapes must be the same (including layout).
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 26a6deb8030..72f69ca2017 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -405,5 +405,33 @@ llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b) {
           EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b)));
 }
 
+bool AreFusedReductionOutputsConsistent(
+    absl::Span output_instructions,
+    const HloInstruction* first_reduce) {
+  for (const HloInstruction* inst : output_instructions) {
+    if (IsReductionFromOrToContiguousDimensions(*inst)) {
+      // Shapes, layouts and dimensions must be the same for all reduces
+      // inside of this fusion.
+      // TODO(tjoerg): Relax the shape constraint. The datatype does not matter.
+      if (!(ShapeUtil::Equal(first_reduce->shape(), inst->shape()) &&
+            ShapeUtil::Equal(first_reduce->operand(0)->shape(),
+                             inst->operand(0)->shape()) &&
+            ShapeUtil::Equal(first_reduce->operand(1)->shape(),
+                             inst->operand(1)->shape()) &&
+            first_reduce->dimensions() == inst->dimensions())) {
+        return false;
+      }
+    } else {
+      if (!(ShapeUtil::CompatibleIgnoringElementType(
+                first_reduce->operand(0)->shape(), inst->shape()) &&
+            LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(),
+                              inst->shape().layout()))) {
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
index f269cf87062..db3cd228841 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -200,6 +200,11 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
 // block 0 of the kernel.
 llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b);
 
+// Returns whether the outputs of a fusion with reduction are consistent.
+bool AreFusedReductionOutputsConsistent(
+    absl::Span output_instructions,
+    const HloInstruction* first_reduce);
+
 }  // namespace gpu
 }  // namespace xla
 
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 06a00d2178a..dbc2c95773a 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -2779,32 +2779,6 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) {
 }
 
 namespace {
-// Checks that the outputs of a fusion with reduction are consistent.
-Status AreFusedReductionOutputsConsistent(
-    absl::Span output_instructions,
-    const HloInstruction* first_reduce) {
-  for (const HloInstruction* inst : output_instructions) {
-    if (IsReductionFromOrToContiguousDimensions(*inst)) {
-      // Shapes, layouts and dimensions must be the same for all reduces
-      // inside of this fusion.
-      TF_RET_CHECK(ShapeUtil::Equal(first_reduce->shape(), inst->shape()));
-      TF_RET_CHECK(ShapeUtil::Equal(first_reduce->operand(0)->shape(),
-                                    inst->operand(0)->shape()));
-      TF_RET_CHECK(ShapeUtil::Equal(first_reduce->operand(1)->shape(),
-                                    inst->operand(1)->shape()));
-      TF_RET_CHECK(first_reduce->dimensions() == inst->dimensions());
-    } else {
-      // For extra outputs we can relax shape equality to allow different
-      // types (with the same number of elements). Layouts still have to
-      // match.
-      TF_RET_CHECK(ShapeUtil::CompatibleIgnoringElementType(
-          first_reduce->operand(0)->shape(), inst->shape()));
-      TF_RET_CHECK(LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(),
-                                     inst->shape().layout()));
-    }
-  }
-  return Status::OK();
-}
 
 // Returns true if all the transitive users of hlo before hitting users in
 // use_chain_endings are elementwise operations.
@@ -2994,8 +2968,10 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
 
   const HloInstruction* first_reduce = reduce_instructions.at(0);
   if (output_instructions.size() > 1) {
-    TF_RETURN_IF_ERROR(
-        AreFusedReductionOutputsConsistent(output_instructions, first_reduce));
+    if (!AreFusedReductionOutputsConsistent(output_instructions,
+                                            first_reduce)) {
+      return InternalError("Inconsistent reduction fusion outputs");
+    }
   }
 
   // Build a kernel thunk to compute all the outputs.

From 37a231e07ad204a184a369c550cbcf3abf268940 Mon Sep 17 00:00:00 2001
From: Yash Katariya 
Date: Mon, 2 Dec 2019 10:48:15 -0800
Subject: [PATCH 152/279] Doc improvements for `tf.disribute.MirroredStrategy`

PiperOrigin-RevId: 283371699
Change-Id: I09ce229f7357916eecf09040164c40cc4009389c
---
 .../python/distribute/mirrored_strategy.py    | 83 +++++++++++++++++--
 1 file changed, 75 insertions(+), 8 deletions(-)

diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py
index 0fb8ae0aafb..85958724002 100644
--- a/tensorflow/python/distribute/mirrored_strategy.py
+++ b/tensorflow/python/distribute/mirrored_strategy.py
@@ -343,19 +343,86 @@ def all_devices():
 
 @tf_export("distribute.MirroredStrategy", v1=[])  # pylint: disable=g-classes-have-attributes
 class MirroredStrategy(distribute_lib.Strategy):
-  """Mirrors vars to distribute across multiple devices and machines.
+  """Synchronous training across multiple replicas on one machine.
 
-  This strategy uses one replica per device and sync replication for its
-  multi-GPU version.
+  This strategy is typically used for training on one
+  machine with multiple GPUs. For TPUs, use
+  `tf.distribute.experimental.TPUStrategy`. To use `MirroredStrategy` with
+  multiple workers, please refer to
+  `tf.distribute.experimental.MultiWorkerMirroredStrategy`.
 
-  To use `MirroredStrategy` with multiple workers, please refer to
-  `tf.distribute.MultiWorkerMirroredStrategy`.
+  For example, a variable created under a `MirroredStrategy` is a
+  `MirroredVariable`. If no devices are specified in the constructor argument of
+  the strategy then it will use all the available GPUs. If no GPUs are found, it
+  will use the available CPUs. Note that TensorFlow treats all CPUs on a
+  machine as a single device, and uses threads internally for parallelism.
+
+  >>> strategy = tf.distribute.MirroredStrategy()
+  >>> with strategy.scope():
+  ...   x = tf.Variable(1.)
+  >>> x
+  MirroredVariable:{
+      0 /job:localhost/replica:0/task:0/device:CPU:0: 
+    }
+
+  While using distribution strategies, all the variable creation should be done
+  within the strategy's scope. This will replicate the variables across all the
+  replicas and keep them in sync using an all-reduce algorithm.
+
+  Variables created inside a `MirroredStrategy` which is wrapped with a
+  `tf.function` are still `MirroredVariables`.
+
+  >>> x = []
+  >>> @tf.function  # Wrap the function with tf.function.
+  ... def create_variable():
+  ...   if not x:
+  ...     x.append(tf.Variable(1.))
+  >>> strategy = tf.distribute.MirroredStrategy()
+  >>> with strategy.scope():
+  ...   create_variable()
+  ...   print (x[0])
+  MirroredVariable:{
+      0 /job:localhost/replica:0/task:0/device:CPU:0: 
+    }
+
+  `experimental_distribute_dataset` can be used to distribute the dataset across
+  the replicas when writing your own training loop. If you are using `.fit` and
+  `.compile` methods available in `tf.keras`, then `tf.keras` will handle the
+  distribution for you.
+
+  For example:
+
+  ```python
+  my_strategy = tf.distribute.MirroredStrategy()
+  with my_strategy.scope():
+    @tf.function
+    def distribute_train_epoch(dataset):
+      def replica_fn(input):
+        # process input and return result
+        return result
+
+      total_result = 0
+      for x in dataset:
+        per_replica_result = my_strategy.experimental_run_v2(replica_fn,
+                                                             args=(x,))
+        total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM,
+                                           per_replica_result, axis=None)
+      return total_result
+
+    dist_dataset = my_strategy.experimental_distribute_dataset(dataset)
+    for _ in range(EPOCHS):
+      train_result = distribute_train_epoch(dist_dataset)
+  ```
 
   Args:
-    devices: a list of device strings.  If `None`, all available GPUs are used.
-    If no GPUs are found, CPU is used.
+    devices: a list of device strings such as `['/gpu:0', '/gpu:1']`.  If
+      `None`, all available GPUs are used. If no GPUs are found, CPU is used.
     cross_device_ops: optional, a descedant of `CrossDeviceOps`. If this is not
-      set, nccl will be used by default.
+      set, `NcclAllReduce()` will be used by default.  One would customize this
+      if NCCL isn't available or if a special implementation that exploits
+      the particular hardware is available.
   """
 
   def __init__(self, devices=None, cross_device_ops=None):

From c16614ef36b990ce9633e25aa00467ea4ce85844 Mon Sep 17 00:00:00 2001
From: Jared Duke 
Date: Mon, 2 Dec 2019 10:50:10 -0800
Subject: [PATCH 153/279] Fix std::uniform_int_distribution usage in benchmark

The standard does not specify support for char/int8 types, and compilation
fails with MSVC.

Fixes #34701

PiperOrigin-RevId: 283372131
Change-Id: Ie50778bd8934dc77eb968d740178948815a40909
---
 tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc | 6 ++++--
 tensorflow/lite/tools/benchmark/benchmark_tflite_model.h  | 5 +++--
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
index 89c83520ab8..dc4a43ee6cb 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
+++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -492,13 +492,15 @@ TfLiteStatus BenchmarkTfLiteModel::PrepareInputData() {
     } else if (t->type == kTfLiteUInt8) {
       int low = has_value_range ? low_range : 0;
       int high = has_value_range ? high_range : 254;
+      // std::uniform_int_distribution is specified not to support char types.
       t_data = CreateInputTensorData(
-          num_elements, std::uniform_int_distribution(low, high));
+          num_elements, std::uniform_int_distribution(low, high));
     } else if (t->type == kTfLiteInt8) {
       int low = has_value_range ? low_range : -127;
       int high = has_value_range ? high_range : 127;
+      // std::uniform_int_distribution is specified not to support char types.
       t_data = CreateInputTensorData(
-          num_elements, std::uniform_int_distribution(low, high));
+          num_elements, std::uniform_int_distribution(low, high));
     } else if (t->type == kTfLiteString) {
       // TODO(haoliang): No need to cache string tensors right now.
     } else {
diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h
index ca7731eed33..a6fc38a6180 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h
+++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h
@@ -90,8 +90,9 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
     InputTensorData tmp;
     tmp.bytes = sizeof(T) * num_elements;
     T* raw = new T[num_elements];
-    std::generate_n(raw, num_elements,
-                    [&]() { return distribution(random_engine_); });
+    std::generate_n(raw, num_elements, [&]() {
+      return static_cast(distribution(random_engine_));
+    });
     // Now initialize the type-erased unique_ptr (with custom deleter) from
     // 'raw'.
     tmp.data = std::unique_ptr(

From 2490c876543961be3e59a112c0ed4bd91e84b522 Mon Sep 17 00:00:00 2001
From: Christian Sigg 
Date: Mon, 2 Dec 2019 11:09:57 -0800
Subject: [PATCH 154/279] Change cuDNN stub version for 7.2 to use 7.3 instead
 of 7.1.

See https://github.com/tensorflow/tensorflow/issues/32350

PiperOrigin-RevId: 283376636
Change-Id: Ifaede15a849924d891bbefe2abf986952f9031bf
---
 tensorflow/stream_executor/cuda/cudnn_stub.cc | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tensorflow/stream_executor/cuda/cudnn_stub.cc b/tensorflow/stream_executor/cuda/cudnn_stub.cc
index 5a05437480e..f683cecdb52 100644
--- a/tensorflow/stream_executor/cuda/cudnn_stub.cc
+++ b/tensorflow/stream_executor/cuda/cudnn_stub.cc
@@ -53,7 +53,8 @@ cudnnStatus_t GetSymbolNotFoundError() { return CUDNN_STATUS_INTERNAL_ERROR; }
 #include "tensorflow/stream_executor/cuda/cudnn_6_0.inc"
 #elif CUDNN_MINOR < 1
 #include "tensorflow/stream_executor/cuda/cudnn_7_0.inc"
-#elif CUDNN_MINOR < 3
+// 2 instead of 3: see https://github.com/tensorflow/tensorflow/issues/32350
+#elif CUDNN_MINOR < 2
 #include "tensorflow/stream_executor/cuda/cudnn_7_1.inc"
 #elif CUDNN_MINOR < 4
 #include "tensorflow/stream_executor/cuda/cudnn_7_3.inc"

From 2a2c812ab2330c9aac33335f10679a346436acfb Mon Sep 17 00:00:00 2001
From: Jiri Simsa 
Date: Mon, 2 Dec 2019 11:20:05 -0800
Subject: [PATCH 155/279] [tf.data] Migrating remaining core API tests to use
 TF combinations and performing various minor test cleanup.

PiperOrigin-RevId: 283378763
Change-Id: Ice08340d289406eb691fb261c20329ada7c23c8a
---
 .../kernel_tests/matching_files_test.py       |   2 +
 .../dataset_serialization_test_base.py        |   1 +
 tensorflow/python/data/kernel_tests/BUILD     |   2 +
 .../kernel_tests/as_numpy_iterator_test.py    |  12 +-
 .../python/data/kernel_tests/cache_test.py    |   2 +-
 .../data/kernel_tests/checkpoint_test.py      |  29 +-
 .../python/data/kernel_tests/dataset_test.py  |  95 +-
 .../python/data/kernel_tests/filter_test.py   |  35 +-
 .../python/data/kernel_tests/flat_map_test.py |   6 +-
 .../data/kernel_tests/from_generator_test.py  |  98 +-
 .../data/kernel_tests/from_tensors_test.py    |   4 +-
 .../kernel_tests/iterator_cluster_test.py     |  23 +-
 .../python/data/kernel_tests/iterator_test.py |  67 +-
 .../data/kernel_tests/list_files_test.py      |   2 +
 .../python/data/kernel_tests/map_test.py      | 868 +++++++++---------
 .../data/kernel_tests/memory_cleanup_test.py  |  17 +-
 .../python/data/kernel_tests/optional_test.py | 166 ++--
 .../python/data/kernel_tests/options_test.py  |  23 +-
 .../data/kernel_tests/padded_batch_test.py    |  56 +-
 .../python/data/kernel_tests/prefetch_test.py |  23 +-
 .../python/data/kernel_tests/range_test.py    |  15 +-
 .../python/data/kernel_tests/repeat_test.py   |  50 +-
 .../python/data/kernel_tests/shard_test.py    |  19 +-
 .../python/data/kernel_tests/shuffle_test.py  |   8 +-
 .../python/data/kernel_tests/skip_test.py     |  46 +-
 .../python/data/kernel_tests/take_test.py     |  38 +-
 .../python/data/kernel_tests/test_base.py     |   6 +-
 .../kernel_tests/text_line_dataset_test.py    |  69 +-
 .../kernel_tests/tf_record_dataset_test.py    |  41 +-
 .../python/data/kernel_tests/unbatch_test.py  |  20 +-
 .../python/data/kernel_tests/window_test.py   |  61 +-
 .../python/data/kernel_tests/zip_test.py      |  50 +-
 tensorflow/python/data/ops/dataset_ops.py     |  26 +-
 33 files changed, 1018 insertions(+), 962 deletions(-)

diff --git a/tensorflow/python/data/experimental/kernel_tests/matching_files_test.py b/tensorflow/python/data/experimental/kernel_tests/matching_files_test.py
index 1bf07a98f28..1240b704119 100644
--- a/tensorflow/python/data/experimental/kernel_tests/matching_files_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/matching_files_test.py
@@ -34,10 +34,12 @@ class MatchingFilesDatasetTest(test_base.DatasetTestBase,
                                parameterized.TestCase):
 
   def setUp(self):
+    super(MatchingFilesDatasetTest, self).setUp()
     self.tmp_dir = tempfile.mkdtemp()
 
   def tearDown(self):
     shutil.rmtree(self.tmp_dir, ignore_errors=True)
+    super(MatchingFilesDatasetTest, self).tearDown()
 
   def _touchTempFiles(self, filenames):
     for filename in filenames:
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py
index f6ab5a1cde2..aea4934260e 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py
@@ -57,6 +57,7 @@ class DatasetSerializationTestBase(test.TestCase):
 
   def tearDown(self):
     self._delete_ckpt()
+    super(DatasetSerializationTestBase, self).tearDown()
 
   # TODO(b/72657739): Remove sparse_tensor argument, which is to test the
   # (deprecated) saveable `SparseTensorSliceDataset`, once the API
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 2b1bda4138a..db749da77f8 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -287,6 +287,7 @@ tf_py_test(
     size = "small",
     srcs = ["iterator_cluster_test.py"],
     additional_deps = [
+        ":test_base",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/data/ops:dataset_ops",
         "//tensorflow/python/data/ops:iterator_ops",
@@ -400,6 +401,7 @@ tf_py_test(
         "//tensorflow/python:variable_scope",
         "//tensorflow/python/ops/ragged",
     ],
+    shard_count = 4,
 )
 
 cuda_py_test(
diff --git a/tensorflow/python/data/kernel_tests/as_numpy_iterator_test.py b/tensorflow/python/data/kernel_tests/as_numpy_iterator_test.py
index 23a928a817c..b704906a3ae 100644
--- a/tensorflow/python/data/kernel_tests/as_numpy_iterator_test.py
+++ b/tensorflow/python/data/kernel_tests/as_numpy_iterator_test.py
@@ -61,30 +61,30 @@ class AsNumpyIteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
     with self.assertRaises(RuntimeError):
       ds.as_numpy_iterator()
 
-  def checkInvalidElement(self, element):
+  def _testInvalidElement(self, element):
     ds = dataset_ops.Dataset.from_tensors(element)
     with self.assertRaisesRegex(TypeError,
                                 '.*does not support datasets containing.*'):
       ds.as_numpy_iterator()
 
   @combinations.generate(test_base.eager_only_combinations())
-  def testInvalidElements(self):
-    self.checkInvalidElement(sparse_tensor.SparseTensorValue([[0]], [0], [1]))
+  def testSparseElement(self):
+    self._testInvalidElement(sparse_tensor.SparseTensorValue([[0]], [0], [1]))
 
   @combinations.generate(test_base.eager_only_combinations())
   def testRaggedElement(self):
-    self.checkInvalidElement(
+    self._testInvalidElement(
         ragged_tensor_value.RaggedTensorValue(
             np.array([0, 1, 2]), np.array([0, 1, 3], dtype=np.int64)))
 
   @combinations.generate(test_base.eager_only_combinations())
   def testDatasetElement(self):
-    self.checkInvalidElement(dataset_ops.Dataset.range(3))
+    self._testInvalidElement(dataset_ops.Dataset.range(3))
 
   @combinations.generate(test_base.eager_only_combinations())
   def testNestedNonTensorElement(self):
     tuple_elem = (constant_op.constant([1, 2, 3]), dataset_ops.Dataset.range(3))
-    self.checkInvalidElement(tuple_elem)
+    self._testInvalidElement(tuple_elem)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/data/kernel_tests/cache_test.py b/tensorflow/python/data/kernel_tests/cache_test.py
index 6d3dc04a3e0..1a923645a04 100644
--- a/tensorflow/python/data/kernel_tests/cache_test.py
+++ b/tensorflow/python/data/kernel_tests/cache_test.py
@@ -45,9 +45,9 @@ class FileCacheTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.cache_prefix = path.join(self.tmp_dir, "cache")
 
   def tearDown(self):
-    super(FileCacheTest, self).tearDown()
     if self.tmp_dir:
       shutil.rmtree(self.tmp_dir, ignore_errors=True)
+    super(FileCacheTest, self).tearDown()
 
   @combinations.generate(test_base.default_test_combinations())
   def testCacheDatasetPassthrough(self):
diff --git a/tensorflow/python/data/kernel_tests/checkpoint_test.py b/tensorflow/python/data/kernel_tests/checkpoint_test.py
index 738d09b97fe..4441d5642d5 100644
--- a/tensorflow/python/data/kernel_tests/checkpoint_test.py
+++ b/tensorflow/python/data/kernel_tests/checkpoint_test.py
@@ -42,11 +42,11 @@ from tensorflow.python.training.tracking import util as trackable_utils
 class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
 
   def tearDown(self):
-    super(CheckpointTest, self).tearDown()
     prefix = self._iterator_checkpoint_prefix()
     pattern = prefix + "*"
     files = gfile.Glob(pattern)
     map(gfile.Remove, files)
+    super(CheckpointTest, self).tearDown()
 
   def _iterator_checkpoint_prefix(self):
     return os.path.join(self.get_temp_dir(), "iterator")
@@ -66,8 +66,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
                                                       iterator_state_variant)
     return restore_op
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode="graph"))
+  @combinations.generate(test_base.graph_only_combinations())
   def testSaveRestore(self):
 
     def _build_graph(start, stop):
@@ -118,8 +117,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
         with self.assertRaises(errors.OutOfRangeError):
           sess.run(get_next)
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode="graph"))
+  @combinations.generate(test_base.graph_only_combinations())
   def testInitThenRestore(self):
     # Note: Calling init_op before restore_op is redundant. This test just makes
     # sure we do not fail if restore is called on an already initialized
@@ -157,8 +155,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
         with self.assertRaises(errors.OutOfRangeError):
           sess.run(get_next)
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode="graph"))
+  @combinations.generate(test_base.graph_only_combinations())
   def testMultipleSaves(self):
 
     def _build_graph(start, stop):
@@ -204,8 +201,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
         with self.assertRaises(errors.OutOfRangeError):
           sess.run(get_next)
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode="graph"))
+  @combinations.generate(test_base.graph_only_combinations())
   def testSaveRestoreWithRepeat(self):
 
     def _build_graph(start, stop, num_epochs):
@@ -253,8 +249,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
         with self.assertRaises(errors.OutOfRangeError):
           sess.run(get_next)
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode="graph"))
+  @combinations.generate(test_base.graph_only_combinations())
   def testSaveRestoreExhaustedIterator(self):
 
     def _build_graph(start, stop, num_epochs):
@@ -295,8 +290,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
         with self.assertRaises(errors.OutOfRangeError):
           sess.run(get_next)
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode="eager"))
+  @combinations.generate(test_base.eager_only_combinations())
   def testSaveRestoreOneShotIterator(self):
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
@@ -319,8 +313,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
     with self.assertRaises(errors.OutOfRangeError):
       get_next()
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode="eager"))
+  @combinations.generate(test_base.eager_only_combinations())
   def testSaveRestoreMultipleIterator(self):
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
@@ -353,8 +346,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertAllEqual([1, 4], get_next_2())
     self.assertAllEqual(3, get_next_3())
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode="eager"))
+  @combinations.generate(test_base.eager_only_combinations())
   def testRestoreExhaustedIterator(self):
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
@@ -373,8 +365,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
     with self.assertRaises(errors.OutOfRangeError):
       get_next()
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode="eager"))
+  @combinations.generate(test_base.eager_only_combinations())
   def testRestoreInReconstructedIteratorInitializable(self):
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
diff --git a/tensorflow/python/data/kernel_tests/dataset_test.py b/tensorflow/python/data/kernel_tests/dataset_test.py
index b35c4ff1b29..df151a85db0 100644
--- a/tensorflow/python/data/kernel_tests/dataset_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_test.py
@@ -43,7 +43,6 @@ from tensorflow.python.framework import tensor_spec
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.platform import test
-from tensorflow.python.platform import tf_logging
 
 
 class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@@ -89,13 +88,13 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
           variant, original_dataset.element_spec)
       self.assertDatasetProduces(revived_dataset, list(original_dataset))
 
-  def checkNumInputs(self, dataset, num_inputs):
+  def _testNumInputs(self, dataset, num_inputs):
     self.assertLen(dataset._inputs(), num_inputs)
 
   @combinations.generate(test_base.default_test_combinations())
   def testFixedLengthRecordInputs(self):
     dataset = readers.FixedLengthRecordDataset("", 42)
-    self.checkNumInputs(dataset, 0)
+    self._testNumInputs(dataset, 0)
 
   @combinations.generate(test_base.default_test_combinations())
   def testFromGeneratorInputs(self):
@@ -103,27 +102,27 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
       yield 42
 
     dataset = dataset_ops.Dataset.from_generator(gen, dtypes.int32)
-    self.checkNumInputs(dataset, 1)
+    self._testNumInputs(dataset, 1)
 
   @combinations.generate(test_base.default_test_combinations())
   def testFromTensorsInputs(self):
     dataset = dataset_ops.Dataset.from_tensors([42])
-    self.checkNumInputs(dataset, 0)
+    self._testNumInputs(dataset, 0)
 
   @combinations.generate(test_base.default_test_combinations())
   def testRangeInputs(self):
     dataset = dataset_ops.Dataset.range(10)
-    self.checkNumInputs(dataset, 0)
+    self._testNumInputs(dataset, 0)
 
   @combinations.generate(test_base.default_test_combinations())
   def testTextLineInputs(self):
     dataset = readers.TextLineDataset("")
-    self.checkNumInputs(dataset, 0)
+    self._testNumInputs(dataset, 0)
 
   @combinations.generate(test_base.default_test_combinations())
   def testTFRecordInputs(self):
     dataset = readers.TFRecordDataset("")
-    self.checkNumInputs(dataset, 1)
+    self._testNumInputs(dataset, 1)
 
   @combinations.generate(
       combinations.combine(tf_api_version=1, mode=["eager", "graph"]))
@@ -135,58 +134,58 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
             dense_shape=np.array([3, 1])))
     self.assertEmpty(dataset_fn._inputs())
 
-  def checkUnaryInputs(self, dataset_fn):
+  def _testUnaryInputs(self, dataset_fn):
     input_dataset = dataset_ops.Dataset.range(0)
     self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())
 
   @combinations.generate(test_base.default_test_combinations())
   def testBatchInputs(self):
-    self.checkUnaryInputs(lambda x: x.batch(10))
+    self._testUnaryInputs(lambda x: x.batch(10))
 
   @combinations.generate(test_base.default_test_combinations())
   def testCacheInputs(self):
-    self.checkUnaryInputs(lambda x: x.cache())
+    self._testUnaryInputs(lambda x: x.cache())
 
   @combinations.generate(test_base.default_test_combinations())
   def testFilterInputs(self):
-    self.checkUnaryInputs(lambda x: x.filter(lambda x: True))
+    self._testUnaryInputs(lambda x: x.filter(lambda x: True))
 
   @combinations.generate(test_base.default_test_combinations())
   def testFlatMapInputs(self):
-    self.checkUnaryInputs(
+    self._testUnaryInputs(
         lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)))
 
   @combinations.generate(test_base.default_test_combinations())
   def testMapInputs(self):
-    self.checkUnaryInputs(lambda x: x.map(lambda x: x))
+    self._testUnaryInputs(lambda x: x.map(lambda x: x))
 
   @combinations.generate(test_base.default_test_combinations())
   def testPaddedBatchInputs(self):
-    self.checkUnaryInputs(lambda x: x.padded_batch(10, []))
+    self._testUnaryInputs(lambda x: x.padded_batch(10, []))
 
   @combinations.generate(test_base.default_test_combinations())
   def testParallelMapInputs(self):
-    self.checkUnaryInputs(lambda x: x.map(lambda x: x, num_parallel_calls=2))
+    self._testUnaryInputs(lambda x: x.map(lambda x: x, num_parallel_calls=2))
 
   @combinations.generate(test_base.default_test_combinations())
   def testRepeatInputs(self):
-    self.checkUnaryInputs(lambda x: x.repeat())
+    self._testUnaryInputs(lambda x: x.repeat())
 
   @combinations.generate(test_base.default_test_combinations())
   def testShuffleInputs(self):
-    self.checkUnaryInputs(lambda x: x.shuffle(10))
+    self._testUnaryInputs(lambda x: x.shuffle(10))
 
   @combinations.generate(test_base.default_test_combinations())
   def testSkipInputs(self):
-    self.checkUnaryInputs(lambda x: x.skip(1))
+    self._testUnaryInputs(lambda x: x.skip(1))
 
   @combinations.generate(test_base.default_test_combinations())
   def testTakeInputs(self):
-    self.checkUnaryInputs(lambda x: x.take(1))
+    self._testUnaryInputs(lambda x: x.take(1))
 
   @combinations.generate(test_base.default_test_combinations())
   def testWindowInputs(self):
-    self.checkUnaryInputs(lambda x: x.window(10))
+    self._testUnaryInputs(lambda x: x.window(10))
 
   @combinations.generate(test_base.default_test_combinations())
   def testUnaryTransformationInputsApply(self):
@@ -195,7 +194,7 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
 
     self.assertEqual([input_dataset], dataset._inputs())
 
-  def checkInputsWithInterleaveFn(self, dataset_fn, interleave_parallelism):
+  def _testInputsWithInterleaveFn(self, dataset_fn, interleave_parallelism):
     input_dataset = dataset_ops.Dataset.range(0)
     dataset = input_dataset.interleave(
         lambda x: dataset_ops.Dataset.range(0),
@@ -205,11 +204,11 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
 
   @combinations.generate(test_base.default_test_combinations())
   def testParallelInterleaveInputs(self):
-    self.checkInputsWithInterleaveFn(lambda: dataset_ops.range(0), 2)
+    self._testInputsWithInterleaveFn(lambda: dataset_ops.range(0), 2)
 
   @combinations.generate(test_base.default_test_combinations())
   def testInterleaveInputs(self):
-    self.checkInputsWithInterleaveFn(lambda: dataset_ops.range(0), None)
+    self._testInputsWithInterleaveFn(lambda: dataset_ops.range(0), None)
 
   @combinations.generate(test_base.default_test_combinations())
   def testNoWarnings(self):
@@ -218,16 +217,16 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
           lambda x: dataset_ops.Dataset.range(0), cycle_length=2)
       self.assertEmpty(mock_log.call_args_list)
 
-  def checkBinaryInputs(self, dataset_fn):
+  def _testBinaryInputs(self, dataset_fn):
     input1 = dataset_ops.Dataset.range(0)
     input2 = dataset_ops.Dataset.range(1)
     self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())
 
   @combinations.generate(test_base.default_test_combinations())
   def testConcatenateInputs(self):
-    self.checkBinaryInputs(lambda x, y: x.concatenate(y))
+    self._testBinaryInputs(lambda x, y: x.concatenate(y))
 
-  def checkVariadicInputs(self, dataset_fn, input_datasets):
+  def _testVariadicInputs(self, dataset_fn, input_datasets):
     self.assertEqual(
         nest.flatten(input_datasets),
         dataset_fn(input_datasets)._inputs())
@@ -235,20 +234,20 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
   @combinations.generate(test_base.default_test_combinations())
   def testZipOneInputs(self):
     input_datasets = dataset_ops.Dataset.range(0)
-    self.checkVariadicInputs(dataset_ops.Dataset.zip, input_datasets)
+    self._testVariadicInputs(dataset_ops.Dataset.zip, input_datasets)
 
   @combinations.generate(test_base.default_test_combinations())
   def testZipNestInputs(self):
     input_datasets = (dataset_ops.Dataset.range(0),
                       (dataset_ops.Dataset.range(1),
                        dataset_ops.Dataset.range(2)))
-    self.checkVariadicInputs(dataset_ops.Dataset.zip, input_datasets)
+    self._testVariadicInputs(dataset_ops.Dataset.zip, input_datasets)
 
   @combinations.generate(test_base.default_test_combinations())
   def testZipTupleInputs(self):
     input_datasets = (dataset_ops.Dataset.range(0),
                       dataset_ops.Dataset.range(1))
-    self.checkVariadicInputs(dataset_ops.Dataset.zip, input_datasets)
+    self._testVariadicInputs(dataset_ops.Dataset.zip, input_datasets)
 
   @combinations.generate(test_base.default_test_combinations())
   def testFunctions(self):
@@ -273,7 +272,7 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertEqual(2, inputs.count(ds2))
     self.assertEqual(1, inputs.count(ds3))
 
-  def checkDatasetSpec(self, tf_value, expected_element_structure):
+  def _testDatasetSpec(self, tf_value, expected_element_structure):
     dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value)
     dataset_structure = structure.type_spec_from_value(dataset)
     self.assertIsInstance(dataset_structure, dataset_ops.DatasetSpec)
@@ -307,12 +306,12 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
 
   @combinations.generate(test_base.default_test_combinations())
   def testTensorDatasetSpec(self):
-    self.checkDatasetSpec(
+    self._testDatasetSpec(
         constant_op.constant(37.0), tensor_spec.TensorSpec([], dtypes.float32))
 
   @combinations.generate(test_base.default_test_combinations())
   def testSparseTensorDatasetSpec(self):
-    self.checkDatasetSpec(
+    self._testDatasetSpec(
         sparse_tensor.SparseTensor(
             indices=[[0]],
             values=constant_op.constant([0], dtype=dtypes.int32),
@@ -320,7 +319,7 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
 
   @combinations.generate(test_base.default_test_combinations())
   def testNestDatasetSpec(self):
-    self.checkDatasetSpec(
+    self._testDatasetSpec(
         {
             "a": constant_op.constant(37.0),
             "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
@@ -335,20 +334,19 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
 
   @combinations.generate(test_base.default_test_combinations())
   def testDatasetDatasetSpec(self):
-    self.checkDatasetSpec(
+    self._testDatasetSpec(
         dataset_ops.Dataset.from_tensor_slices(
             constant_op.constant([1, 2, 3])),
         dataset_ops.DatasetSpec(tensor_spec.TensorSpec([], dtypes.int32)))
 
   @combinations.generate(test_base.default_test_combinations())
   def testOptionalDatasetSpec(self):
-    self.checkDatasetSpec(
+    self._testDatasetSpec(
         optional_ops.Optional.from_value(37.0),
         optional_ops.OptionalSpec(tensor_spec.TensorSpec([], dtypes.float32)))
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1], mode=["graph"]))
-  def testSkipEagerSameGraphErrorOneShot(self):
+  @combinations.generate(test_base.graph_only_combinations())
+  def testSameGraphError(self):
     dataset = dataset_ops.Dataset.range(10)
     with ops.Graph().as_default():
       with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
@@ -356,26 +354,27 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
 
   @combinations.generate(
       combinations.combine(tf_api_version=[1], mode=["graph"]))
-  def testSkipEagerSameGraphErrorOneShotSimple(self):
+  def testSameGraphErrorOneShot(self):
     dataset = dataset_ops.Dataset.range(10)
     with ops.Graph().as_default():
-      with test.mock.patch.object(tf_logging, "warning") as mock_log:
+      with self.assertRaisesRegexp(
+          ValueError, "Please ensure that all datasets in the pipeline are "
+          "created in the same graph as the iterator."):
         _ = dataset_ops.make_one_shot_iterator(dataset)
-        self.assertRegexpMatches(
-            str(mock_log.call_args), "Please ensure that all datasets in the "
-            "pipeline are created in the same graph as the iterator.")
 
   @combinations.generate(
       combinations.combine(tf_api_version=[1], mode=["graph"]))
-  def testSkipEagerSameGraphErrorInitializable(self):
+  def testSameGraphErrorInitializable(self):
     dataset = dataset_ops.Dataset.range(10)
     with ops.Graph().as_default():
-      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
-        dataset = dataset.batch(2)
+      with self.assertRaisesRegexp(
+          ValueError, "Please ensure that all datasets in the pipeline are "
+          "created in the same graph as the iterator."):
+        _ = dataset_ops.make_initializable_iterator(dataset)
 
   @combinations.generate(
       combinations.times(
-          combinations.combine(tf_api_version=[1, 2], mode="eager"),
+          test_base.eager_only_combinations(),
           combinations.combine(execution_mode=[context.ASYNC, context.SYNC])))
   def testEagerIteration(self, execution_mode):
     with context.execution_mode(execution_mode):
diff --git a/tensorflow/python/data/kernel_tests/filter_test.py b/tensorflow/python/data/kernel_tests/filter_test.py
index 05b538a46ce..f6bdcb12020 100644
--- a/tensorflow/python/data/kernel_tests/filter_test.py
+++ b/tensorflow/python/data/kernel_tests/filter_test.py
@@ -30,28 +30,31 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 
 
-def new_and_legacy_filter_fn_combinations():
+def _test_combinations():
 
-  def new_filter_fn(dataset, predicate):
+  def filter_fn(dataset, predicate):
     return dataset.filter(predicate)
 
   def legacy_filter_fn(dataset, predicate):
     return dataset.filter_with_legacy_function(predicate)
 
-  return (combinations.combine(
+  filter_combinations = combinations.combine(
       tf_api_version=[1, 2],
       mode=["eager", "graph"],
-      apply_filter=combinations.NamedObject("new_filter_fn", new_filter_fn)) +
-          combinations.combine(
-              tf_api_version=1,
-              mode=["eager", "graph"],
-              apply_filter=combinations.NamedObject("legacy_filter_fn",
-                                                    legacy_filter_fn)))
+      apply_filter=combinations.NamedObject("filter_fn", filter_fn))
+
+  legacy_filter_combinations = combinations.combine(
+      tf_api_version=1,
+      mode=["eager", "graph"],
+      apply_filter=combinations.NamedObject("legacy_filter_fn",
+                                            legacy_filter_fn))
+
+  return filter_combinations + legacy_filter_combinations
 
 
 class FilterTest(test_base.DatasetTestBase, parameterized.TestCase):
 
-  @combinations.generate(new_and_legacy_filter_fn_combinations())
+  @combinations.generate(_test_combinations())
   def testFilterDataset(self, apply_filter):
     components = (np.arange(7, dtype=np.int64),
                   np.array([[1, 2, 3]], dtype=np.int64) *
@@ -87,14 +90,14 @@ class FilterTest(test_base.DatasetTestBase, parameterized.TestCase):
     # Test an empty dataset.
     do_test(0, 1)
 
-  @combinations.generate(new_and_legacy_filter_fn_combinations())
+  @combinations.generate(_test_combinations())
   def testFilterRange(self, apply_filter):
     dataset = dataset_ops.Dataset.range(4)
     dataset = apply_filter(dataset,
                            lambda x: math_ops.not_equal(math_ops.mod(x, 3), 2))
     self.assertDatasetProduces(dataset, expected_output=[0, 1, 3])
 
-  @combinations.generate(new_and_legacy_filter_fn_combinations())
+  @combinations.generate(_test_combinations())
   def testFilterDict(self, apply_filter):
     dataset = dataset_ops.Dataset.range(10).map(
         lambda x: {"foo": x * 2, "bar": x**2})
@@ -104,7 +107,7 @@ class FilterTest(test_base.DatasetTestBase, parameterized.TestCase):
         dataset,
         expected_output=[(i * 2 + i**2) for i in range(10) if not (i**2) % 2])
 
-  @combinations.generate(new_and_legacy_filter_fn_combinations())
+  @combinations.generate(_test_combinations())
   def testUseStepContainerInFilter(self, apply_filter):
     input_data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
 
@@ -119,7 +122,7 @@ class FilterTest(test_base.DatasetTestBase, parameterized.TestCase):
     dataset = apply_filter(dataset, _predicate)
     self.assertDatasetProduces(dataset, expected_output=[input_data[0]])
 
-  @combinations.generate(new_and_legacy_filter_fn_combinations())
+  @combinations.generate(_test_combinations())
   def testSparse(self, apply_filter):
 
     def _map_fn(i):
@@ -137,7 +140,7 @@ class FilterTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertDatasetProduces(
         dataset, expected_output=[_map_fn(i * 2)[0] for i in range(5)])
 
-  @combinations.generate(new_and_legacy_filter_fn_combinations())
+  @combinations.generate(_test_combinations())
   def testShortCircuit(self, apply_filter):
     dataset = dataset_ops.Dataset.zip(
         (dataset_ops.Dataset.range(10),
@@ -146,7 +149,7 @@ class FilterTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertDatasetProduces(
         dataset, expected_output=[(i, True) for i in range(10)])
 
-  @combinations.generate(new_and_legacy_filter_fn_combinations())
+  @combinations.generate(_test_combinations())
   def testParallelFilters(self, apply_filter):
     dataset = dataset_ops.Dataset.range(10)
     dataset = apply_filter(dataset, lambda x: math_ops.equal(x % 2, 0))
diff --git a/tensorflow/python/data/kernel_tests/flat_map_test.py b/tensorflow/python/data/kernel_tests/flat_map_test.py
index 3afed61fc7f..00b6e400ea7 100644
--- a/tensorflow/python/data/kernel_tests/flat_map_test.py
+++ b/tensorflow/python/data/kernel_tests/flat_map_test.py
@@ -66,10 +66,8 @@ class FlatMapTest(test_base.DatasetTestBase, parameterized.TestCase):
         expected_output.extend([i] * i)
     self.assertDatasetProduces(dataset, expected_output=expected_output)
 
-  # Note: no eager mode coverage, session specific test.
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
-  def testSkipEagerSharedResourceNestedFlatMapDataset(self):
+  @combinations.generate(test_base.graph_only_combinations())
+  def testSharedResourceNestedFlatMapDataset(self):
     repeats = [[1, 2], [3, 4], [5, 0], [1, 7]]
     components = np.array(repeats, dtype=np.int64)
     iterator = (
diff --git a/tensorflow/python/data/kernel_tests/from_generator_test.py b/tensorflow/python/data/kernel_tests/from_generator_test.py
index dfa467add62..49753babacb 100644
--- a/tensorflow/python/data/kernel_tests/from_generator_test.py
+++ b/tensorflow/python/data/kernel_tests/from_generator_test.py
@@ -32,62 +32,83 @@ from tensorflow.python.ops import script_ops
 from tensorflow.python.platform import test
 
 
-class DatasetConstructorTest(test_base.DatasetTestBase, parameterized.TestCase):
+class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase):
 
   def _testFromGenerator(self, generator, elem_sequence, num_repeats,
-                         output_types=None):
-    if output_types is None:
-      output_types = dtypes.int64
-    dataset = dataset_ops.Dataset.from_generator(
-        generator, output_types=output_types).repeat(num_repeats).prefetch(5)
-    self.assertDatasetProduces(
-        dataset,
-        elem_sequence * num_repeats,
-        requires_initialization=True,
-        num_test_iterations=2)
-
-  def _testFromGeneratorOneShot(self, generator, elem_sequence, num_repeats):
+                         requires_initialization):
     dataset = dataset_ops.Dataset.from_generator(
         generator, output_types=dtypes.int64).repeat(num_repeats).prefetch(5)
     self.assertDatasetProduces(
-        dataset, elem_sequence * num_repeats, num_test_iterations=2)
+        dataset,
+        elem_sequence * num_repeats,
+        requires_initialization=requires_initialization,
+        num_test_iterations=2)
+
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(
+              num_repeats=[1, 5], requires_initialization=[True, False])))
+  def testFromGeneratorUsingFn(self, num_repeats, requires_initialization):
 
-  @combinations.generate(test_base.default_test_combinations())
-  def testFromGeneratorUsingFunction(self):
     def generator():
       for i in range(1, 100):
         yield [i] * i
-    elem_sequence = list(generator())
-    self._testFromGenerator(generator, elem_sequence, 1)
-    self._testFromGenerator(generator, elem_sequence, 5)
-    self._testFromGeneratorOneShot(generator, elem_sequence, 1)
-    self._testFromGeneratorOneShot(generator, elem_sequence, 5)
 
-  @combinations.generate(test_base.default_test_combinations())
-  def testFromGeneratorUsingList(self):
+    elem_sequence = list(generator())
+    self._testFromGenerator(
+        generator,
+        elem_sequence,
+        num_repeats=num_repeats,
+        requires_initialization=requires_initialization)
+
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(
+              num_repeats=[1, 5], requires_initialization=[True, False])))
+  def testFromGeneratorUsingList(self, num_repeats, requires_initialization):
     generator = lambda: [[i] * i for i in range(1, 100)]
     elem_sequence = list(generator())
-    self._testFromGenerator(generator, elem_sequence, 1)
-    self._testFromGenerator(generator, elem_sequence, 5)
+    self._testFromGenerator(
+        generator,
+        elem_sequence,
+        num_repeats=num_repeats,
+        requires_initialization=requires_initialization)
 
-  @combinations.generate(test_base.default_test_combinations())
-  def testFromGeneratorUsingNdarray(self):
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(
+              num_repeats=[1, 5], requires_initialization=[True, False])))
+  def testFromGeneratorUsingNdarray(self, num_repeats, requires_initialization):
     generator = lambda: np.arange(100, dtype=np.int64)
     elem_sequence = list(generator())
-    self._testFromGenerator(generator, elem_sequence, 1, output_types=np.int64)
-    self._testFromGenerator(generator, elem_sequence, 5, output_types=np.int64)
+    self._testFromGenerator(
+        generator,
+        elem_sequence,
+        num_repeats=num_repeats,
+        requires_initialization=requires_initialization)
 
-  @combinations.generate(test_base.default_test_combinations())
-  def testFromGeneratorUsingGeneratorExpression(self):
-    # NOTE(mrry): Generator *expressions* are not repeatable (or in
-    # general reusable), because they eagerly evaluate the `for`
-    # expression as `iter(range(1, 100))` and discard the means of
-    # reconstructing `range(1, 100)`. Wrapping the generator
-    # expression in a `lambda` makes it repeatable.
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(
+              num_repeats=[1, 5], requires_initialization=[True, False])))
+  def testFromGeneratorUsingGeneratorExpression(self, num_repeats,
+                                                requires_initialization):
+    # NOTE(mrry): Generator *expressions* are not repeatable (or in general
+    # reusable), because they eagerly evaluate the `for` expression as
+    # `iter(range(1, 100))` and discard the means of reconstructing
+    # `range(1, 100)`. Wrapping the generator expression in a `lambda` makes
+    # it repeatable.
     generator = lambda: ([i] * i for i in range(1, 100))
     elem_sequence = list(generator())
-    self._testFromGenerator(generator, elem_sequence, 1)
-    self._testFromGenerator(generator, elem_sequence, 5)
+    self._testFromGenerator(
+        generator,
+        elem_sequence,
+        num_repeats=num_repeats,
+        requires_initialization=requires_initialization)
 
   @combinations.generate(test_base.default_test_combinations())
   def testFromMultipleConcurrentGenerators(self):
@@ -392,7 +413,6 @@ class DatasetConstructorTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertAllEqual(37, self.evaluate(get_next()))
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(get_next())
-      self.assertTrue(event.is_set())
 
   @combinations.generate(test_base.default_test_combinations())
   def testSharedName(self):
diff --git a/tensorflow/python/data/kernel_tests/from_tensors_test.py b/tensorflow/python/data/kernel_tests/from_tensors_test.py
index e293383403f..c899c156739 100644
--- a/tensorflow/python/data/kernel_tests/from_tensors_test.py
+++ b/tensorflow/python/data/kernel_tests/from_tensors_test.py
@@ -237,8 +237,8 @@ class FromTensorsTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertEqual([3], get_next().shape)
 
   # TODO(b/121264236): needs mechanism for multiple device in eager mode.
-  @combinations.generate(test_base.default_test_combinations())
-  def testSkipEagerSplitPipeline(self):
+  @combinations.generate(test_base.graph_only_combinations())
+  def testSplitPipeline(self):
     with session.Session(
         target="",
         config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
diff --git a/tensorflow/python/data/kernel_tests/iterator_cluster_test.py b/tensorflow/python/data/kernel_tests/iterator_cluster_test.py
index 0384f9fc18a..0a40c212006 100644
--- a/tensorflow/python/data/kernel_tests/iterator_cluster_test.py
+++ b/tensorflow/python/data/kernel_tests/iterator_cluster_test.py
@@ -17,12 +17,15 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.framework import combinations
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -37,9 +40,9 @@ from tensorflow.python.ops import string_ops
 from tensorflow.python.platform import test
 
 
-class IteratorClusterTest(test.TestCase):
+class IteratorClusterTest(test.TestCase, parameterized.TestCase):
 
-  @test_util.run_v1_only("b/120545219")
+  @combinations.generate(test_base.graph_only_combinations())
   def testRemoteIteratorWithoutRemoteCallFail(self):
     worker_config = config_pb2.ConfigProto()
     worker_config.device_count["CPU"] = 2
@@ -95,7 +98,7 @@ class IteratorClusterTest(test.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(remote_op, feed_dict={target_placeholder: device1})
 
-  @test_util.run_v1_only("b/120545219")
+  @combinations.generate(test_base.graph_only_combinations())
   def testRemoteIteratorUsingRemoteCallOp(self):
     worker_config = config_pb2.ConfigProto()
     worker_config.device_count["CPU"] = 2
@@ -106,7 +109,7 @@ class IteratorClusterTest(test.TestCase):
                                    "/job:worker/replica:0/task:0/cpu:1",
                                    worker[0].target)
 
-  @test_util.run_v1_only("b/120545219")
+  @combinations.generate(test_base.graph_only_combinations())
   def testRemoteIteratorUsingRemoteCallOpCrossProcess(self):
     workers, _ = test_util.create_local_cluster(2, 1)
 
@@ -114,7 +117,7 @@ class IteratorClusterTest(test.TestCase):
                                    "/job:worker/replica:0/task:1/cpu:0",
                                    workers[0].target)
 
-  @test_util.run_v1_only("b/120545219")
+  @combinations.generate(test_base.graph_only_combinations())
   def testCaptureHashTableInSharedIterator(self):
     worker, _ = test_util.create_local_cluster(1, 1)
 
@@ -131,10 +134,10 @@ class IteratorClusterTest(test.TestCase):
     input_sentences = dataset_ops.Dataset.from_tensor_slices(
         ["brain brain tank salad surgery", "surgery brain"])
 
-    iterator = (
-        input_sentences.map(lambda x: string_ops.string_split([x]).values).map(
-            table.lookup)
-        .make_initializable_iterator(shared_name="shared_iterator"))
+    dataset = input_sentences.map(
+        lambda x: string_ops.string_split([x]).values).map(table.lookup)
+    iterator = dataset_ops.make_initializable_iterator(
+        dataset, shared_name="shared_iterator")
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
@@ -148,7 +151,7 @@ class IteratorClusterTest(test.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
 
-  @test_util.run_v1_only("b/120545219")
+  @combinations.generate(test_base.graph_only_combinations())
   def testImplicitDisposeParallelMapDataset(self):
     # Tests whether a parallel map dataset will be cleaned up correctly when
     # the pipeline does not run it until exhaustion.
diff --git a/tensorflow/python/data/kernel_tests/iterator_test.py b/tensorflow/python/data/kernel_tests/iterator_test.py
index 70f03f3e4d2..fcb2e4c0b1f 100644
--- a/tensorflow/python/data/kernel_tests/iterator_test.py
+++ b/tensorflow/python/data/kernel_tests/iterator_test.py
@@ -56,8 +56,7 @@ from tensorflow.python.util import compat
 
 class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testNoGradients(self):
     component = constant_op.constant([1.])
     side = constant_op.constant(0.)
@@ -68,8 +67,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertIsNone(gradients_impl.gradients(value, side)[0])
     self.assertIsNone(gradients_impl.gradients(value, [component, side])[0])
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testCapturingStateInOneShotRaisesException(self):
     var = variables.Variable(37.0, name="myvar")
     dataset = (
@@ -80,8 +78,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
         "datasets that capture stateful objects.+myvar"):
       dataset_ops.make_one_shot_iterator(dataset)
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testOneShotIterator(self):
     components = (np.arange(7),
                   np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
@@ -107,8 +104,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testOneShotIteratorCaptureByValue(self):
     components = (np.arange(7),
                   np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
@@ -172,8 +168,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
         with self.assertRaises(errors.OutOfRangeError):
           sess.run(get_next)
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testOneShotIteratorNonBlocking(self):
     dataset = dataset_ops.Dataset.from_tensors([1, 2, 3]).map(lambda x: x * x)
     iterator = dataset_ops.make_one_shot_iterator(dataset)
@@ -207,13 +202,11 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
       for t in threads:
         t.join()
 
-      self.assertEqual(num_threads, len(results))
-      self.assertEqual(num_threads - 1,
-                       len([None for r in results if r is None]))
+      self.assertLen(results, num_threads)
+      self.assertLen([None for r in results if r is None], num_threads - 1)
       self.assertAllEqual([[1, 4, 9]], [r for r in results if r is not None])
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testOneShotIteratorInitializerFails(self):
     # Define a dataset whose initialization will always fail.
     dataset = dataset_ops.Dataset.from_tensors(array_ops.gather([0], [4]))
@@ -243,8 +236,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
       for t in threads:
         t.join()
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testSimpleSharedResource(self):
     components = (np.array(1, dtype=np.int64),
                   np.array([1, 2, 3], dtype=np.int64),
@@ -294,8 +286,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
         with self.assertRaises(errors.OutOfRangeError):
           sess.run(get_next)
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testNotInitializedError(self):
     components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
     iterator = dataset_ops.make_initializable_iterator(
@@ -307,8 +298,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
                                    "iterator has not been initialized"):
         sess.run(get_next)
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testReinitializableIterator(self):
     dataset_3 = dataset_ops.Dataset.from_tensors(
         constant_op.constant([1, 2, 3]))
@@ -353,8 +343,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testReinitializableIteratorWithFunctions(self):
 
     def g():
@@ -415,8 +404,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
               (constant_op.constant([1, 2, 3], dtype=dtypes.int64),
                constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float64))))
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testIteratorStringHandle(self):
     dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
     dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])
@@ -474,8 +462,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
         sess.run(
             next_element, feed_dict={handle_placeholder: iterator_4_handle})
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testIteratorStringHandleFuture(self):
     with forward_compat.forward_compatibility_horizon(2018, 8, 4):
       dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
@@ -541,8 +528,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
           sess.run(
               next_element, feed_dict={handle_placeholder: iterator_4_handle})
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testIteratorStringHandleReuseTensorObject(self):
     dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
     one_shot_iterator = dataset_ops.make_one_shot_iterator(dataset)
@@ -571,8 +557,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertEqual("foo_1", handle_with_same_name.op.name)
     self.assertIsNot(handle_with_name, handle_with_same_name)
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testIteratorStringHandleError(self):
     dataset_int_scalar = (
         dataset_ops.Dataset.from_tensor_slices([1, 2, 3]).repeat())
@@ -613,8 +598,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
             feedable_int_vector.get_next(),
             feed_dict={handle_placeholder: handle_float_vector}))
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testRemoteIteratorUsingRemoteCallOpDirectSession(self):
     worker_config = config_pb2.ConfigProto()
     worker_config.device_count["CPU"] = 3
@@ -672,8 +656,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
                 target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
             })
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testRemoteIteratorUsingRemoteCallOpMultiWorkers(self):
     s1 = server_lib.Server.create_local_server()
     s2 = server_lib.Server.create_local_server()
@@ -727,8 +710,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(n)
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self):
     if not test_util.is_gpu_available():
       self.skipTest("No GPU available")
@@ -785,8 +767,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
                 target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
             })
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testRepeatedGetNextWarning(self):
     iterator = dataset_ops.make_one_shot_iterator(dataset_ops.Dataset.range(10))
     warnings.simplefilter("always")
@@ -929,7 +910,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
         self.assertEqual(val, foo.numpy())
         val += 1
 
-  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
+  @combinations.generate(test_base.eager_only_combinations())
   def testOwnedIteratorFunction(self):
 
     queue = data_flow_ops.FIFOQueue(10, dtypes.int64)
@@ -946,7 +927,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
     for i in range(10):
       self.assertEqual(queue.dequeue().numpy(), i)
 
-  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
+  @combinations.generate(test_base.eager_only_combinations())
   def testOwnedIteratorFunctionError(self):
     # In this test we verify that a function that raises an error ends up
     # properly deallocating the iterator resource.
@@ -976,7 +957,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
 
     self.assertEqual(queue.size().numpy(), 2)
 
-  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
+  @combinations.generate(test_base.eager_only_combinations())
   def testLimitedRetracing(self):
     trace_count = [0]
 
@@ -996,7 +977,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
       self.assertEqual(self.evaluate(f(iter(dataset2))), 45)
       self.assertEqual(trace_count[0], 1)
 
-  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
+  @combinations.generate(test_base.eager_only_combinations())
   def testNestedFunctionsIteratorResource(self):
 
     @def_function.function
diff --git a/tensorflow/python/data/kernel_tests/list_files_test.py b/tensorflow/python/data/kernel_tests/list_files_test.py
index 52ce300f537..40b4b77116c 100644
--- a/tensorflow/python/data/kernel_tests/list_files_test.py
+++ b/tensorflow/python/data/kernel_tests/list_files_test.py
@@ -35,10 +35,12 @@ from tensorflow.python.util import compat
 class ListFilesTest(test_base.DatasetTestBase, parameterized.TestCase):
 
   def setUp(self):
+    super(ListFilesTest, self).setUp()
     self.tmp_dir = tempfile.mkdtemp()
 
   def tearDown(self):
     shutil.rmtree(self.tmp_dir, ignore_errors=True)
+    super(ListFilesTest, self).tearDown()
 
   def _touchTempFiles(self, filenames):
     for filename in filenames:
diff --git a/tensorflow/python/data/kernel_tests/map_test.py b/tensorflow/python/data/kernel_tests/map_test.py
index 0847cdd7a0d..c8b23edbc7f 100644
--- a/tensorflow/python/data/kernel_tests/map_test.py
+++ b/tensorflow/python/data/kernel_tests/map_test.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import functools
 from collections import namedtuple
 import threading
 import time
@@ -31,13 +32,13 @@ from tensorflow.python.data.experimental.ops import threading_options
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import context
+from tensorflow.python.framework import combinations
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_util
-from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import data_flow_ops
@@ -57,10 +58,70 @@ from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.platform import test
 
 
-def _make_coordinated_sloppy_dataset(num_elements, num_parallel_calls):
+def _test_combinations_with_mode_v1(mode):
+
+  def new_map_fn(dataset, *args, **kwargs):
+    return dataset.map(*args, **kwargs)
+
+  def legacy_map_fn(dataset, *args, **kwargs):
+    return dataset.map_with_legacy_function(*args, **kwargs)
+
+  new_map_combinations = combinations.combine(
+      tf_api_version=1,
+      mode=mode,
+      apply_map=combinations.NamedObject("map_fn", new_map_fn))
+
+  legacy_map_combinations = combinations.combine(
+      tf_api_version=1,
+      mode=mode,
+      apply_map=combinations.NamedObject("legacy_map_fn", legacy_map_fn))
+
+  return new_map_combinations + legacy_map_combinations
+
+
+def _test_combinations_with_mode_v2(mode):
+
+  def new_map_fn(dataset, *args, **kwargs):
+    return dataset.map(*args, **kwargs)
+
+  return combinations.combine(
+      tf_api_version=2,
+      mode=mode,
+      apply_map=combinations.NamedObject("map_fn", new_map_fn))
+
+
+def _test_combinations_with_mode(mode):
+  return _test_combinations_with_mode_v1(
+      mode) + _test_combinations_with_mode_v2(mode)
+
+
+def _test_combinations():
+  return _test_combinations_with_mode("eager") + _test_combinations_with_mode(
+      "graph")
+
+
+def _short_circuit_test_cases():
+  cases = [
+      ("Identity", None, lambda x: x),
+      ("Replicate", None, lambda x: (x, x)),
+      ("Swap", (None, None), lambda x, y: (y, x)),
+      ("Project", (None, None), lambda x, y: x)
+  ]
+
+  def reduce_fn(x, y):
+    name, structure, fn = y
+    return x + combinations.combine(
+        structure=structure, fn=combinations.NamedObject(name, fn))
+
+  return functools.reduce(reduce_fn, cases, [])
+
+
+def _make_coordinated_sloppy_dataset(apply_map, num_elements,
+                                     num_parallel_calls):
   """Produces a dataset iterator and events to control the order of elements.
 
   Args:
+    apply_map: method that applies the `map` transformation
     num_elements: the number of input elements
     num_parallel_calls: the degree of map parallelism
 
@@ -84,28 +145,27 @@ def _make_coordinated_sloppy_dataset(num_elements, num_parallel_calls):
 
   options = dataset_ops.Options()
   options.experimental_deterministic = False
-  dataset = dataset_ops.Dataset.range(num_elements).map(
-      map_fn, num_parallel_calls).with_options(options)
+  dataset = dataset_ops.Dataset.range(num_elements)
+  dataset = apply_map(dataset, map_fn, num_parallel_calls).with_options(options)
   return dataset, coordination_events
 
 
-# TODO(jsimsa): Add tests for `map_with_legacy_function`.
-@test_util.run_all_in_graph_and_eager_modes
 class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
 
-  def _buildMapDataset(self, components, count):
+  def _map_dataset_factory(self, components, apply_map, count):
 
     def _map_fn(x, y, z):
       return math_ops.square(x), math_ops.square(y), math_ops.square(z)
 
-    dataset = dataset_ops.Dataset.from_tensor_slices(components).map(
-        _map_fn).repeat(count)
+    dataset = dataset_ops.Dataset.from_tensor_slices(components)
+    dataset = apply_map(dataset, _map_fn).repeat(count)
     self.assertEqual(
         [c.shape[1:] for c in components],
         [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
     return dataset
 
-  def testMapDataset(self):
+  @combinations.generate(_test_combinations())
+  def testMapDataset(self, apply_map):
     """Test an dataset that maps a TF function across its input elements."""
     # The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
     # RepeatDataset(count).
@@ -114,7 +174,8 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
                   np.array(37.0) * np.arange(7))
 
     # Test single-threaded access to the iterator.
-    get_next = self.getNext(self._buildMapDataset(components, 14))
+    get_next = self.getNext(
+        self._map_dataset_factory(components, apply_map, count=14))
     for _ in range(14):
       for i in range(7):
         result = self.evaluate(get_next())
@@ -123,15 +184,15 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(get_next())
 
-  # TODO(b/117581999): add eager coverage, different threads run in graph
-  # context.
-  @test_util.run_v1_only("b/120545219")
-  def testSkipEagerMapDatasetMultithreaded(self):
+  # TODO(b/117581999): add eager coverage
+  @combinations.generate(_test_combinations_with_mode("graph"))
+  def testMapDatasetMultiThreaded(self, apply_map):
     # Test multi-threaded access to the same iterator.
     components = (np.arange(7),
                   np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                   np.array(37.0) * np.arange(7))
-    get_next = self.getNext(self._buildMapDataset(components, 18))
+    get_next = self.getNext(
+        self._map_dataset_factory(components, apply_map, count=18))
     results = []
     with self.cached_session() as sess:
       def iterator_thread():
@@ -157,94 +218,99 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
                                                  results[i * 18 + j]):
             self.assertAllEqual(component[i]**2, result_component)
 
-  def _buildParallelMapDataset(self, components, count, num_parallel_calls,
-                               output_buffer_size):
+  def _parallel_map_dataset_factory(self, components, apply_map, count,
+                                    num_parallel_calls, buffer_size):
 
     def _map_fn(x, y, z):
       return math_ops.square(x), math_ops.square(y), math_ops.square(z)
 
-    dataset = dataset_ops.Dataset.from_tensor_slices(components).map(
-        _map_fn, num_parallel_calls=num_parallel_calls).prefetch(
-            output_buffer_size).repeat(count)
+    dataset = dataset_ops.Dataset.from_tensor_slices(components)
+    dataset = apply_map(dataset, _map_fn, num_parallel_calls=num_parallel_calls)
+    dataset = dataset.prefetch(buffer_size).repeat(count)
 
     self.assertEqual(
         [c.shape[1:] for c in components],
         [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
     return dataset
 
-  def testParallelMapDataset(self):
+  @combinations.generate(
+      combinations.times(
+          _test_combinations(),
+          combinations.combine(num_parallel_calls=1, buffer_size=1) +
+          combinations.combine(num_parallel_calls=1, buffer_size=2) +
+          combinations.combine(num_parallel_calls=2, buffer_size=2) +
+          combinations.combine(num_parallel_calls=2, buffer_size=4) +
+          combinations.combine(num_parallel_calls=8, buffer_size=8) +
+          combinations.combine(num_parallel_calls=8, buffer_size=16)))
+  def testParallelMapDataset(self, apply_map, num_parallel_calls, buffer_size):
     """Test an dataset that maps a TF function across its input elements."""
 
     # The pipeline is TensorSliceDataset -> ParallelMapDataset(square_3) ->
     # RepeatDataset(count).
-    def do_test(num_parallel_calls, output_buffer_size):
+    components = (np.arange(7),
+                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
+                  np.array(37.0) * np.arange(7))
+    # Test single-threaded access to the iterator.
+    get_next = self.getNext(
+        self._parallel_map_dataset_factory(components, apply_map, 14,
+                                           num_parallel_calls, buffer_size))
+    for _ in range(14):
+      for i in range(7):
+        result = self.evaluate(get_next())
+        for component, result_component in zip(components, result):
+          self.assertAllEqual(component[i]**2, result_component)
+    with self.assertRaises(errors.OutOfRangeError):
+      self.evaluate(get_next())
 
-      components = (np.arange(7),
-                    np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
-                    np.array(37.0) * np.arange(7))
-      # Test single-threaded access to the iterator.
-      get_next = self.getNext(
-          self._buildParallelMapDataset(components, 14, num_parallel_calls,
-                                        output_buffer_size))
-      for _ in range(14):
-        for i in range(7):
-          result = self.evaluate(get_next())
-          for component, result_component in zip(components, result):
+  # TODO(b/117581999): add eager coverage
+  @combinations.generate(
+      combinations.times(
+          _test_combinations_with_mode("graph"),
+          combinations.combine(num_parallel_calls=1, buffer_size=1) +
+          combinations.combine(num_parallel_calls=1, buffer_size=2) +
+          combinations.combine(num_parallel_calls=2, buffer_size=2) +
+          combinations.combine(num_parallel_calls=2, buffer_size=4) +
+          combinations.combine(num_parallel_calls=8, buffer_size=8) +
+          combinations.combine(num_parallel_calls=8, buffer_size=16)))
+  def testParallelMapDatasetMultiThreaded(self, apply_map, num_parallel_calls,
+                                          buffer_size):
+
+    # Test multi-threaded access to the same iterator.
+    components = (np.arange(7),
+                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
+                  np.array(37.0) * np.arange(7))
+    get_next = self.getNext(
+        self._parallel_map_dataset_factory(components, apply_map, 18,
+                                           num_parallel_calls, buffer_size))
+    results = []
+    with self.cached_session() as sess:
+
+      def iterator_thread():
+        while True:
+          try:
+            results.append(sess.run(get_next()))
+          except errors.OutOfRangeError:
+            return
+
+      threads = [self.checkedThread(target=iterator_thread) for _ in range(64)]
+      for t in threads:
+        t.start()
+      for t in threads:
+        t.join()
+
+      # `results` will contain the same elements components**2
+      # repeated 18 times, but in a non-deterministic order. Sort the
+      # results, and assert that each element of components**2 is
+      # produced 18 times.
+      results.sort(key=lambda x: x[0])
+      for i in range(7):
+        for j in range(18):
+          for component, result_component in zip(components,
+                                                 results[i * 18 + j]):
             self.assertAllEqual(component[i]**2, result_component)
-      with self.assertRaises(errors.OutOfRangeError):
-        self.evaluate(get_next())
 
-    for num_parallel_calls_val, output_buffer_size_val in [(1, 1), (1, 2), (2,
-                                                                            2),
-                                                           (2, 4), (8, 8),
-                                                           (8, 16)]:
-      do_test(num_parallel_calls_val, output_buffer_size_val)
-
-  # TODO(b/117581999): add eager coverage, different threads run in graph
-  # context.
-  @test_util.run_v1_only("b/120545219")
-  def testSkipEagerParallelMapDatasetMultithreaded(self):
-
-    def do_test(num_parallel_calls, output_buffer_size):
-      # Test multi-threaded access to the same iterator.
-      components = (np.arange(7),
-                    np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
-                    np.array(37.0) * np.arange(7))
-      get_next = self.getNext(
-          self._buildParallelMapDataset(components, 18, num_parallel_calls,
-                                        output_buffer_size))
-      results = []
-      with self.cached_session() as sess:
-
-        def iterator_thread():
-          while True:
-            try:
-              results.append(sess.run(get_next()))
-            except errors.OutOfRangeError:
-              return
-        threads = [self.checkedThread(target=iterator_thread)
-                   for _ in range(64)]
-        for t in threads:
-          t.start()
-        for t in threads:
-          t.join()
-
-        # `results` will contain the same elements components**2
-        # repeated 18 times, but in a non-deterministic order. Sort the
-        # results, and assert that each element of components**2 is
-        # produced 18 times.
-        results.sort(key=lambda x: x[0])
-        for i in range(7):
-          for j in range(18):
-            for component, result_component in zip(components,
-                                                   results[i * 18 + j]):
-              self.assertAllEqual(component[i]**2, result_component)
-
-      for num_parallel_calls_val, output_buffer_size_val in [
-          (1, 1), (1, 2), (2, 2), (2, 4), (8, 8), (8, 16)]:
-        do_test(num_parallel_calls_val, output_buffer_size_val)
-
-  def testImplicitDisposeParallelMapDataset(self):
+  @combinations.generate(_test_combinations())
+  def testImplicitDisposeParallelMapDataset(self, apply_map):
     # Tests whether a parallel map dataset will be cleaned up correctly when
     # the pipeline does not run it until exhaustion.
     # The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
@@ -253,7 +319,8 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
                   np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
                   np.array(37.0) * np.arange(1000))
 
-    dataset = self._buildParallelMapDataset(components, 1000, 100, 100)
+    dataset = self._parallel_map_dataset_factory(components, apply_map, 1000,
+                                                 100, 100)
     # NOTE(mrry): Also test that the prefetching thread is cancelled correctly.
     dataset = dataset.prefetch(100)
     get_next = self.getNext(dataset)
@@ -261,23 +328,29 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
     for _ in range(3):
       self.evaluate(get_next())
 
-  def testParallelMapUnspecifiedOutputSize(self):
+  @combinations.generate(_test_combinations())
+  def testParallelMapUnspecifiedOutputSize(self, apply_map):
     components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
 
-    dataset = (dataset_ops.Dataset.from_tensor_slices(components)
-               .map(lambda x: array_ops.check_numerics(x, "message"),
-                    num_parallel_calls=2))
+    dataset = dataset_ops.Dataset.from_tensor_slices(components)
+    dataset = apply_map(
+        dataset,
+        lambda x: array_ops.check_numerics(x, "message"),
+        num_parallel_calls=2)
     get_next = self.getNext(dataset)
 
     for _ in range(3):
       self.evaluate(get_next())
 
-  def testParallelMapError(self):
+  @combinations.generate(_test_combinations())
+  def testParallelMapError(self, apply_map):
     components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
 
-    dataset = (dataset_ops.Dataset.from_tensor_slices(components)
-               .map(lambda x: array_ops.check_numerics(x, "message"),
-                    num_parallel_calls=2))
+    dataset = dataset_ops.Dataset.from_tensor_slices(components)
+    dataset = apply_map(
+        dataset,
+        lambda x: array_ops.check_numerics(x, "message"),
+        num_parallel_calls=2)
     get_next = self.getNext(dataset)
 
     for _ in range(3):
@@ -289,13 +362,13 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(get_next())
 
-  def testPrefetchError(self):
+  @combinations.generate(_test_combinations())
+  def testPrefetchError(self, apply_map):
     components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
 
-    dataset = (dataset_ops.Dataset.from_tensor_slices(components)
-               .map(lambda x: array_ops.check_numerics(x, "message"))
-               .prefetch(2))
-
+    dataset = dataset_ops.Dataset.from_tensor_slices(components)
+    dataset = apply_map(
+        dataset, lambda x: array_ops.check_numerics(x, "message")).prefetch(2)
     get_next = self.getNext(dataset)
 
     for _ in range(3):
@@ -307,7 +380,8 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(get_next())
 
-  def testCaptureIterator(self):
+  @combinations.generate(_test_combinations())
+  def testCaptureIterator(self, apply_map):
 
     def _build_ds(iterator):
 
@@ -315,7 +389,7 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
         get_next = iterator.get_next()
         return x * get_next
 
-      return dataset_ops.Dataset.range(10).map(_map_fn)
+      return apply_map(dataset_ops.Dataset.range(10), _map_fn)
 
     def _build_graph():
       if context.executing_eagerly():
@@ -335,7 +409,8 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(get_next())
 
-  def testCaptureHashTable(self):
+  @combinations.generate(_test_combinations())
+  def testCaptureHashTable(self, apply_map):
     # NOTE(mrry): We must use the V2 variants of `HashTable`
     # etc. because these produce a `tf.resource`-typed output that is
     # compatible with the in-graph function implementation.
@@ -348,8 +423,9 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
     input_sentences = dataset_ops.Dataset.from_tensor_slices(
         ["brain brain tank salad surgery", "surgery brain"])
 
-    dataset = input_sentences.map(lambda x: string_ops.string_split([x]).values
-                                 ).map(table.lookup)
+    dataset = apply_map(input_sentences,
+                        lambda x: string_ops.string_split([x]).values)
+    dataset = apply_map(dataset, table.lookup)
 
     get_next = self.getNext(dataset, requires_initialization=True)
 
@@ -359,14 +435,15 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(get_next())
 
-  @test_util.run_v1_only("b/123904513")
-  def testCaptureQueue(self):
+  # TODO(b/123904513)
+  @combinations.generate(_test_combinations_with_mode_v1("graph"))
+  def testCaptureQueue(self, apply_map):
     elements = np.random.randint(100, size=[200])
     queue = data_flow_ops.FIFOQueue(200, dtypes.int64, shapes=[])
     enqueue_op = queue.enqueue_many(elements)
     close_op = queue.close()
-    dataset = dataset_ops.Dataset.from_tensors(0).repeat(
-        -1).map(lambda _: queue.dequeue())
+    dataset = dataset_ops.Dataset.from_tensors(0).repeat(-1)
+    dataset = apply_map(dataset, lambda _: queue.dequeue())
 
     get_next = self.getNext(dataset, requires_initialization=True)
     self.evaluate(enqueue_op)
@@ -378,8 +455,8 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
       self.evaluate(get_next())
 
   # TODO(b/117581999): Possible deadlock in eager mode, debug.
-  @test_util.run_v1_only("b/120545219")
-  def testSkipEagerCaptureSameResourceMultipleTimes(self):
+  @combinations.generate(_test_combinations_with_mode_v1("graph"))
+  def testCaptureSameResourceMultipleTimes(self, apply_map):
     elements = np.random.randint(100, size=[200])
     queue = data_flow_ops.FIFOQueue(
         200, dtypes.int64, shapes=[], shared_name="shared_queue")
@@ -389,8 +466,8 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
     enqueue_op = queue.enqueue_many(elements)
     close_op = queue.close()
 
-    dataset = dataset_ops.Dataset.from_tensors(0).repeat(
-        -1).map(lambda _: (queue.dequeue(), queue_2.dequeue()))
+    dataset = dataset_ops.Dataset.from_tensors(0).repeat(-1)
+    dataset = apply_map(dataset, lambda _: (queue.dequeue(), queue_2.dequeue()))
 
     self.evaluate(enqueue_op)
     self.evaluate(close_op)
@@ -401,9 +478,11 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(get_next())
 
-  def testSeededStatefulOperatorIsProperlyStateful(self):
-    dataset = dataset_ops.Dataset.from_tensors(0).repeat(
-        10).map(lambda _: random_ops.random_uniform((), seed=11)).batch(2)
+  @combinations.generate(_test_combinations())
+  def testSeededStatefulOperatorIsProperlyStateful(self, apply_map):
+    dataset = dataset_ops.Dataset.from_tensors(0).repeat(10)
+    fn = lambda _: random_ops.random_uniform((), seed=11)
+    dataset = apply_map(dataset, fn).batch(2)
 
     get_next = self.getNext(dataset, requires_initialization=True)
     random_values = []
@@ -422,9 +501,11 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
     # Randomness is repeatable given same seed
     self.assertAllClose(random_values, random_values_2)
 
-  def testStatefulMapKeepsStateAcrossIterators(self):
-    dataset = dataset_ops.Dataset.from_tensors(0).repeat(10).map(
-        lambda _: random_ops.random_uniform((), seed=11)).repeat(1000).batch(10)
+  @combinations.generate(_test_combinations())
+  def testStatefulMapKeepsStateAcrossIterators(self, apply_map):
+    dataset = dataset_ops.Dataset.from_tensors(0).repeat(10)
+    fn = lambda _: random_ops.random_uniform((), seed=11)
+    dataset = apply_map(dataset, fn).repeat(1000).batch(10)
 
     get_next = self.getNext(dataset)
     random_values = self.evaluate(get_next())
@@ -438,7 +519,8 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
       i += 1
     self.assertLess(i, 99)
 
-  def testStatefulOperationInShortCircuit(self):
+  @combinations.generate(_test_combinations())
+  def testStatefulOperationInShortCircuit(self, apply_map):
     counter_var = variable_scope.get_variable(
         "counter", (), dtypes.int32, use_resource=True)
 
@@ -446,7 +528,8 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
       counter_var.assign_add(1)
       return x
 
-    dataset = dataset_ops.Dataset.range(10).map(increment_fn)
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = apply_map(dataset, increment_fn)
 
     get_next = self.getNext(dataset, requires_initialization=True)
 
@@ -459,22 +542,24 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
       self.evaluate(get_next())
     self.assertEqual(10, self.evaluate(counter_var))
 
-  def testMapDict(self):
-    dataset = dataset_ops.Dataset.range(10).map(
-        lambda x: {"foo": x * 2, "bar": x**2}).map(
-            lambda d: d["foo"] + d["bar"])
+  @combinations.generate(_test_combinations())
+  def testMapDict(self, apply_map):
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = apply_map(dataset, lambda x: {"foo": x * 2, "bar": x**2})
+    dataset = apply_map(dataset, lambda d: d["foo"] + d["bar"])
     self.assertDatasetProduces(
         dataset, expected_output=[i * 2 + i**2 for i in range(10)])
 
-  def testMapNamedtuple(self, count=10):
+  @combinations.generate(_test_combinations())
+  def testMapNamedtuple(self, apply_map):
     # construct dataset of tuples
-    labels = dataset_ops.Dataset.range(count)
-    images = labels.map(lambda l: -l)
+    labels = dataset_ops.Dataset.range(10)
+    images = apply_map(labels, lambda l: -l)
     dataset_tuple = dataset_ops.Dataset.zip((labels, images))
 
     # convert dataset of tuples to dataset of namedtuples
     example = namedtuple("Example", ["label", "image"])
-    dataset_namedtuple = dataset_tuple.map(example)
+    dataset_namedtuple = apply_map(dataset_tuple, example)
 
     def preprocess_tuple(label, image):
       image = 2 * image
@@ -484,14 +569,14 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
       return example._replace(image=2 * example.image)
 
     # preprocess both datasets
-    dataset_tuple = dataset_tuple.map(preprocess_tuple)
-    dataset_namedtuple = dataset_namedtuple.map(preprocess_namedtuple)
+    dataset_tuple = apply_map(dataset_tuple, preprocess_tuple)
+    dataset_namedtuple = apply_map(dataset_namedtuple, preprocess_namedtuple)
 
     next_tuple = self.getNext(dataset_tuple)
     next_namedtuple = self.getNext(dataset_namedtuple)
 
     # make sure both datasets contain the same data
-    for i in range(count):
+    for i in range(10):
       tuple_, namedtuple_ = self.evaluate([next_tuple(), next_namedtuple()])
       self.assertEqual(tuple_, namedtuple_)
       self.assertEqual(tuple_, (i, -2 * i))
@@ -499,13 +584,16 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(next_namedtuple())
 
-  def testUseStepContainerInMap(self):
+  @combinations.generate(_test_combinations())
+  def testUseStepContainerInMap(self, apply_map):
     row = np.arange(6)
-    dataset = dataset_ops.Dataset.from_tensors(
-        row).map(lambda elems: map_fn.map_fn(lambda x: x * x, elems))
+    dataset = dataset_ops.Dataset.from_tensors(row)
+    dataset = apply_map(dataset,
+                        lambda elems: map_fn.map_fn(lambda x: x * x, elems))
     self.assertDatasetProduces(dataset, expected_output=[row**2])
 
-  def testCaseAndCondInMap(self):
+  @combinations.generate(_test_combinations())
+  def testCaseAndCondInMap(self, apply_map):
 
     def control_map_fn(x, y):
 
@@ -531,13 +619,12 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
           pred_fn_pairs, default=multiply, exclusive=True)
 
     def build_dataset(row, num):
-      dataset = dataset_ops.Dataset.from_tensor_slices(
-          row).map(lambda x: control_map_fn(x, num))
-      return self.getNext(dataset)
+      dataset = dataset_ops.Dataset.from_tensor_slices(row)
+      return apply_map(dataset, lambda x: control_map_fn(x, num))
 
     row = np.arange(6)
     for num in [2, 3, 4]:
-      get_next = build_dataset(row, num)
+      get_next = self.getNext(build_dataset(row, num))
       for i in range(6):
         self.assertEqual(
             (i // 2 if i % 2 else i * 2) if (num == 2 or num == 3) else i * 2,
@@ -545,7 +632,8 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         self.evaluate(get_next())
 
-  def testCaseInWhileInMap(self):
+  @combinations.generate(_test_combinations())
+  def testCaseInWhileInMap(self, apply_map):
 
     def control_map_fn(x, y):
 
@@ -564,22 +652,22 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
           pred_fn_pairs, default=multiply, exclusive=True)
 
     def build_dataset(row, num):
-      # pylint: disable=g-long-lambda
-      dataset = dataset_ops.Dataset.from_tensors(
-          row).map(lambda elems: map_fn.map_fn(
-              lambda x: control_map_fn(x, num), elems))
-      return self.getNext(dataset)
+      dataset = dataset_ops.Dataset.from_tensors(row)
+      return apply_map(
+          dataset,
+          lambda elems: map_fn.map_fn(lambda x: control_map_fn(x, num), elems))
 
     row = np.arange(6)
     for num in [2, 3, 4]:
-      get_next = build_dataset(row, num)
+      get_next = self.getNext(build_dataset(row, num))
       self.assertAllEqual(
           [x // 2 if (num == 2 or num == 3) else x * 2 for x in row],
           self.evaluate(get_next()))
       with self.assertRaises(errors.OutOfRangeError):
         self.evaluate(get_next())
 
-  def testCaseAndCondInWhileInMap(self):
+  @combinations.generate(_test_combinations())
+  def testCaseAndCondInWhileInMap(self, apply_map):
 
     def control_map_fn(x, y):
 
@@ -606,11 +694,10 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
 
     row = np.arange(6)
     num = 2
-    # pylint: disable=g-long-lambda
-    dataset = dataset_ops.Dataset.from_tensors(
-        row).map(lambda elems: map_fn.map_fn(
-            lambda x: control_map_fn(x, num), elems))
-    # pylint: enable=g-long-lambda
+    dataset = dataset_ops.Dataset.from_tensors(row)
+    dataset = apply_map(
+        dataset,
+        lambda elems: map_fn.map_fn(lambda x: control_map_fn(x, num), elems))
     get_next = self.getNext(dataset)
 
     self.assertAllEqual([(x // 2 if x % 2 else x * 2) if
@@ -619,17 +706,20 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(get_next())
 
-  def testNestedListMapDataset(self):
-    dataset = dataset_ops.Dataset.from_tensors(
-        [0, 1, 2]).repeat(10).map(lambda a: ([a[1], a[0] + a[2]], a[1]))
-
+  @combinations.generate(_test_combinations())
+  def testNestedListMapDataset(self, apply_map):
+    dataset = dataset_ops.Dataset.from_tensors([0, 1, 2]).repeat(10)
+    dataset = apply_map(dataset, lambda a: ([a[1], a[0] + a[2]], a[1]))
     expected_output = [(np.array([1, 2]), 1)] * 10
     self.assertDatasetProduces(dataset, expected_output=expected_output)
 
-  def testPrefetch(self):
-    # We will use this event to test that `_map_py_func()` has been
-    # invoked a certain number of times (6 times, to be exact) after
-    # consuming fewer elements from the iterator.
+  @combinations.generate(
+      combinations.times(_test_combinations(),
+                         combinations.combine(buffer_size=[1, 2, 3, 4])))
+  def testPrefetch(self, apply_map, buffer_size):
+    # We will use this event to test that `_map_py_func()` has been invoked a
+    # certain number of times (6 times, to be exact) after consuming fewer
+    # elements from the iterator.
     ev = threading.Event()
 
     set_event_during_invocation = 5
@@ -642,56 +732,38 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
     def _map_fn(x):
       return script_ops.py_func(_map_py_func, [x], x.dtype)
 
-    def do_test(buffer_size):
-      dataset = dataset_ops.Dataset.range(100).map(_map_fn).prefetch(
-          buffer_size)
-
-      get_next = self.getNext(dataset)
-      # Simple test that prefetch yields the expected values in the
-      # expected order.
-      for i in range(100):
-        self.assertEqual(i * i, self.evaluate(get_next()))
-      with self.assertRaises(errors.OutOfRangeError):
-        self.evaluate(get_next())
-
-    for buffer_size in [1, 10, 100, 1000]:
-      do_test(buffer_size)
-
-    # We can indirectly observe that varying the buffer size has the
-    # intended effect by observing when `ev` is set (on the 6th
-    # invocation of `_map_py_func()`).
+    # We can indirectly observe that varying the buffer size has the intended
+    # effect by observing when `ev` is set (on the 6th invocation of
+    # `_map_py_func()`).
     # NOTE(mrry): We do not test with `buffer_size ==
-    # set_event_during_invocation`, because we must consume at least
-    # one element to start the prefetching.
-    def do_test_ev(buffer_size):
-      dataset = dataset_ops.Dataset.range(100).map(_map_fn).prefetch(
-          buffer_size)
+    # set_event_during_invocation`, because we must consume at least one element
+    # to start the prefetching.
+    dataset = dataset_ops.Dataset.range(100)
+    dataset = apply_map(dataset, _map_fn).prefetch(buffer_size)
+    get_next = self.getNext(dataset)
 
-      get_next = self.getNext(dataset)
+    event_will_be_set_after_consuming = (
+        set_event_during_invocation - buffer_size + 1)
 
-      event_will_be_set_after_consuming = (
-          set_event_during_invocation - buffer_size + 1)
+    ev.clear()
+    for i in range(event_will_be_set_after_consuming):
+      self.assertFalse(ev.is_set())
+      self.assertEqual(i * i, self.evaluate(get_next()))
+    ev.wait()
+    for i in range(event_will_be_set_after_consuming, 100):
+      self.assertEqual(i * i, self.evaluate(get_next()))
+    with self.assertRaises(errors.OutOfRangeError):
+      self.evaluate(get_next())
 
-      ev.clear()
-      for i in range(event_will_be_set_after_consuming):
-        self.assertFalse(ev.is_set())
-        self.assertEqual(i * i, self.evaluate(get_next()))
-      ev.wait()
-      for i in range(event_will_be_set_after_consuming, 100):
-        self.assertEqual(i * i, self.evaluate(get_next()))
-      with self.assertRaises(errors.OutOfRangeError):
-        self.evaluate(get_next())
-
-    for buffer_size in range(1, set_event_during_invocation):
-      do_test_ev(buffer_size)
-
-  def testReturnList(self):
-    dataset = dataset_ops.Dataset.range(
-        10).map(lambda x: [x, constant_op.constant(37.0)])
+  @combinations.generate(_test_combinations())
+  def testReturnList(self, apply_map):
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = apply_map(dataset, lambda x: [x, constant_op.constant(37.0)])
     self.assertDatasetProduces(
         dataset, expected_output=[(i, 37.0) for i in range(10)])
 
-  def testMultiOutputPyFunc(self):
+  @combinations.generate(_test_combinations())
+  def testMultiOutputPyFunc(self, apply_map):
     # The `tf.py_func()` op returns a list of tensors for its outputs.
     def _map_fn(x_tensor):
       def _map_py_func(x):
@@ -699,11 +771,13 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
       return script_ops.py_func(
           _map_py_func, [x_tensor], [dtypes.int64, dtypes.float64])
 
-    dataset = dataset_ops.Dataset.range(10).map(_map_fn)
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = apply_map(dataset, _map_fn)
     self.assertDatasetProduces(
         dataset, expected_output=[(i, 37.0) for i in range(10)])
 
-  def testSparse(self):
+  @combinations.generate(_test_combinations())
+  def testSparse(self, apply_map):
 
     def _sparse(i):
       return sparse_tensor.SparseTensorValue(
@@ -711,11 +785,13 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
           values=(i * np.array([1])),
           dense_shape=np.array([1, 1]))
 
-    dataset = dataset_ops.Dataset.range(10).map(_sparse)
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = apply_map(dataset, _sparse)
     self.assertDatasetProduces(
         dataset, expected_output=[_sparse(i) for i in range(10)])
 
-  def testSparseChain(self):
+  @combinations.generate(_test_combinations())
+  def testSparseChain(self, apply_map):
 
     def _sparse(i):
       return sparse_tensor.SparseTensorValue(
@@ -727,37 +803,38 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
       self.assertTrue(sparse_tensor.is_sparse(i))
       return sparse_ops.sparse_concat(0, [i, i])
 
-    dataset = dataset_ops.Dataset.range(10).map(_sparse).map(_check)
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = apply_map(dataset, _sparse)
+    dataset = apply_map(dataset, _check)
 
     self.assertDatasetProduces(
         dataset,
         expected_output=[self.evaluate(_check(_sparse(i))) for i in range(10)])
 
-  def testSparseMapShapeInference(self):
-    if not context.executing_eagerly():
-      self.skipTest("SparseTensor shape inference requires eager mode")
+  @combinations.generate(_test_combinations_with_mode("eager"))
+  def testSparseMapShapeInference(self, apply_map):
     row_lengths = np.random.randint(0, 4, size=128)
     values = np.ones(np.sum(row_lengths))
     sparse = ragged_tensor.RaggedTensor.from_row_lengths(
         values, row_lengths).to_sparse()
     dataset = dataset_ops.Dataset.from_tensor_slices(sparse)
     dataset = dataset.batch(32, drop_remainder=True)
-    dataset = dataset.map(lambda x: x)
+    dataset = apply_map(dataset, lambda x: x)
     self.assertEqual((32, 3), dataset.element_spec.shape)
 
-  def testSparseMapShapeInferencePartial(self):
-    if not context.executing_eagerly():
-      self.skipTest("SparseTensor shape inference requires eager mode")
+  @combinations.generate(_test_combinations_with_mode("eager"))
+  def testSparseMapShapeInferencePartial(self, apply_map):
     row_lengths = np.random.randint(0, 4, size=128)
     values = np.ones(np.sum(row_lengths))
     sparse = ragged_tensor.RaggedTensor.from_row_lengths(
         values, row_lengths).to_sparse()
     dataset = dataset_ops.Dataset.from_tensor_slices(sparse)
     dataset = dataset.batch(32, drop_remainder=False)
-    dataset = dataset.map(lambda x: x)
+    dataset = apply_map(dataset, lambda x: x)
     self.assertEqual([None, 3], dataset.element_spec.shape.as_list())
 
-  def testTensorArray(self):
+  @combinations.generate(_test_combinations())
+  def testTensorArray(self, apply_map):
 
     def _tensor_array(i):
       i = math_ops.cast(i, dtypes.int32)
@@ -765,11 +842,13 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
           tensor_array_ops.TensorArray(dtypes.int32, element_shape=(), size=i)
           .unstack(math_ops.range(i, dtype=dtypes.int32)))
 
-    dataset = dataset_ops.Dataset.range(10).map(_tensor_array)
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = apply_map(dataset, _tensor_array)
     self.assertDatasetProduces(
         dataset, expected_output=[list(range(i)) for i in range(10)])
 
-  def testTensorArrayChain(self):
+  @combinations.generate(_test_combinations())
+  def testTensorArrayChain(self, apply_map):
 
     def _tensor_array(i):
       i = math_ops.cast(i, dtypes.int32)
@@ -781,23 +860,28 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
       self.assertIsInstance(x, tensor_array_ops.TensorArray)
       return x.identity()
 
-    dataset = dataset_ops.Dataset.range(10).map(_tensor_array).map(_check)
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = apply_map(dataset, _tensor_array)
+    dataset = apply_map(dataset, _check)
 
     self.assertDatasetProduces(
         dataset,
         expected_output=[list(range(i)) for i in range(10)])
 
-  def testRagged(self):
+  @combinations.generate(_test_combinations())
+  def testRagged(self, apply_map):
 
     def _ragged(i):
       return ragged_tensor.RaggedTensor.from_tensor(i * [[1]])
 
-    dataset = dataset_ops.Dataset.range(5).map(_ragged)
+    dataset = dataset_ops.Dataset.range(5)
+    dataset = apply_map(dataset, _ragged)
     self.assertDatasetProduces(
         dataset,
         expected_output=[ragged_factory_ops.constant([[i]]) for i in range(5)])
 
-  def testRaggedChain(self):
+  @combinations.generate(_test_combinations())
+  def testRaggedChain(self, apply_map):
 
     def _ragged(i):
       return ragged_tensor.RaggedTensor.from_tensor(i * [[1]])
@@ -806,7 +890,9 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
       self.assertTrue(ragged_tensor.is_ragged(i))
       return ragged_concat_ops.concat([i, i], 0)
 
-    dataset = dataset_ops.Dataset.range(10).map(_ragged).map(_concat)
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = apply_map(dataset, _ragged)
+    dataset = apply_map(dataset, _concat)
 
     self.assertDatasetProduces(
         dataset,
@@ -815,15 +901,19 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
             for i in range(10)
         ])
 
-  @test_util.run_v1_only("b/123904513")
-  def testParallelMapOutOfRangeError(self):
+  # TODO(b/123904513)
+  @combinations.generate(_test_combinations_with_mode_v1("graph"))
+  def testParallelMapOutOfRangeError(self, apply_map):
+
     def raising_py_func(i):
       if i == 100:
         raise StopIteration()
       else:
         return i
 
-    dataset = dataset_ops.Dataset.range(105).map(
+    dataset = dataset_ops.Dataset.range(105)
+    dataset = apply_map(
+        dataset,
         lambda x: script_ops.py_func(raising_py_func, [x], dtypes.int64),
         num_parallel_calls=2)
     get_next = self.getNext(dataset)
@@ -832,11 +922,15 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(get_next())
 
-  def testConstantOutput(self):
-    dataset = dataset_ops.Dataset.range(10).map(lambda x: [x, "hello", 10])
+  @combinations.generate(_test_combinations())
+  def testConstantOutput(self, apply_map):
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = apply_map(dataset, lambda x: [x, "hello", 10])
     self.assertDatasetProduces(dataset, [(i, b"hello", 10) for i in range(10)])
 
-  def testWarnOnLookupTable(self):
+  @combinations.generate(_test_combinations())
+  def testWarnOnLookupTable(self, apply_map):
+
     def collecting_function(x):
       _ = lookup_ops.HashTable(
           lookup_ops.KeyValueTensorInitializer(["a"], [1.]), 0.0, name="t1")
@@ -844,30 +938,8 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
 
     warnings.simplefilter("always")
     with warnings.catch_warnings(record=True) as w:
-      _ = dataset_ops.Dataset.range(10).map(collecting_function)
-    # NOTE(mrry): Python 3 prints other warnings in addition to the one we are
-    # testing, so we search for the expected warning.
-    self.assertGreaterEqual(len(w), 1)
-    found_warning = False
-    for warning in w:
-      if ("Creating resources inside a function passed to Dataset.map() is "
-          "not supported." in str(warning)):
-        found_warning = True
-        break
-    self.assertTrue(found_warning)
-
-  @test_util.run_v1_only("map_with_legacy_function v1 only")
-  def testWarnOnLookupTableLegacyFunction(self):
-
-    def collecting_function(x):
-      _ = lookup_ops.HashTable(
-          lookup_ops.KeyValueTensorInitializer(["a"], [1.]), 0.0, name="t1")
-      return x
-
-    warnings.simplefilter("always")
-    with warnings.catch_warnings(record=True) as w:
-      _ = dataset_ops.Dataset.range(10).map_with_legacy_function(
-          collecting_function)
+      dataset = dataset_ops.Dataset.range(10)
+      _ = apply_map(dataset, collecting_function)
     # NOTE(mrry): Python 3 prints other warnings in addition to the one we are
     # testing, so we search for the expected warning.
     self.assertGreaterEqual(len(w), 1)
@@ -879,21 +951,25 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
         break
     self.assertTrue(found_warning)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testWarnOnSeedFromOuterGraph(self):
     with ops.Graph().as_default() as g:
       g.seed = 10
       warnings.simplefilter("always")
 
+      def _check_warning(caught_warnings, expected_result):
+        found_warning = False
+        for warning in caught_warnings:
+          if ("Explicitly set the seed in the function if this is not the "
+              "intended behavior" in str(warning)):
+            found_warning = True
+            break
+        self.assertEqual(found_warning, expected_result)
+
       # map_fun doesn't use seed, so no warning is generated.
       with warnings.catch_warnings(record=True) as w:
         _ = dataset_ops.Dataset.range(10).map(math_ops.square)
-      found_warning = False
-      for warning in w:
-        if ("Explicitly set the seed in the function if this is not the "
-            "intended behavior" in str(warning)):
-          found_warning = True
-          break
-      self.assertFalse(found_warning)
+      _check_warning(w, False)
 
       def random_func(x):
         x = math_ops.add(x, 1)
@@ -902,14 +978,7 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
 
       with warnings.catch_warnings(record=True) as w:
         _ = dataset_ops.Dataset.range(10).map(random_func)
-      self.assertGreaterEqual(len(w), 1)
-      found_warning = False
-      for warning in w:
-        if ("Explicitly set the seed in the function if this is not the "
-            "intended behavior" in str(warning)):
-          found_warning = True
-          break
-      self.assertTrue(found_warning)
+      _check_warning(w, True)
 
       def random_func_seeded(x):
         ops.get_default_graph().seed = None
@@ -918,41 +987,30 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
 
       with warnings.catch_warnings(record=True) as w:
         _ = dataset_ops.Dataset.range(10).batch(2).map(random_func_seeded)
-      found_warning = False
-      for warning in w:
-        if ("Explicitly set the seed in the function if this is not the "
-            "intended behavior" in str(warning)):
-          found_warning = True
-          break
-      self.assertFalse(found_warning)
+      _check_warning(w, False)
 
       with warnings.catch_warnings(record=True) as w:
-        _ = dataset_ops.Dataset.range(10).batch(
-            2).map(lambda x: random_ops.random_shuffle(x, seed=37))
-      found_warning = False
-      for warning in w:
-        if ("Explicitly set the seed in the function if this is not the "
-            "intended behavior" in str(warning)):
-          found_warning = True
-          break
-      self.assertFalse(found_warning)
+        _ = dataset_ops.Dataset.range(10).batch(2).map(
+            lambda x: random_ops.random_shuffle(x, seed=37))
+      _check_warning(w, False)
 
-  def testNestedDatasetMap(self):
-    # TODO(b/110122868): When iterators can yield a `tf.data.Dataset`, remove
-    # the `get_single_element()` call.
-    dataset = dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0]).map(
-        dataset_ops.Dataset.from_tensor_slices).map(
-            lambda ds: ds.batch(3)).flat_map(lambda x: x)
+  @combinations.generate(_test_combinations())
+  def testNestedDatasetMap(self, apply_map):
+    dataset = dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0])
+    dataset = apply_map(dataset, dataset_ops.Dataset.from_tensor_slices)
+    dataset = apply_map(dataset, lambda ds: ds.batch(3)).flat_map(lambda x: x)
 
     self.assertDatasetProduces(dataset, expected_output=[[1.0, 2.0, 3.0]])
 
-  def testReturnValueError(self):
+  @combinations.generate(_test_combinations())
+  def testReturnValueError(self, apply_map):
     dataset = dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0])
     with self.assertRaisesRegexp(
         TypeError, r"Unsupported return value from function passed to "
         r"Dataset.map\(\): None."):
-      _ = dataset.map(lambda x: None)
+      _ = apply_map(dataset, lambda x: None)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testBrokenFunctionErrorOnInitialization(self):
     dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0, 3.0])
 
@@ -965,8 +1023,7 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
               value, dtype=dtypes.float32, shape=[0], verify_shape=False))
       dtype_value = attr_value_pb2.AttrValue(type=dtypes.int32.as_datatype_enum)
 
-      # Create a "Const" op with a `tf.float32` value and a `tf.int32` type
-      # attr.
+      # Create a "Const" op with a `tf.float32` value and a `tf.int32` type.
       const_tensor = ops.get_default_graph().create_op(
           "Const", [], [dtypes.int32],
           attrs={
@@ -980,15 +1037,11 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertDatasetProduces(
         dataset, expected_error=(errors.InvalidArgumentError, "BrokenConst"))
 
-# pylint: disable=g-long-lambda
-  @parameterized.named_parameters(
-      ("Map", lambda dataset, func:
-       dataset_ops.MapDataset(dataset, func, use_inter_op_parallelism=False)),
-      ("ParallelMap", lambda dataset, func:
-       dataset_ops.ParallelMapDataset(dataset, func, num_parallel_calls=1,
-                                      use_inter_op_parallelism=False)),
-  )
-  def testNoInterOpParallelism(self, make_dataset_fn):
+  @combinations.generate(
+      combinations.times(
+          _test_combinations_with_mode("graph"),
+          combinations.combine(num_parallel_calls=[None, 12])))
+  def testNoInterOpParallelism(self, apply_map, num_parallel_calls):
     dataset = dataset_ops.Dataset.from_tensors(0)
 
     def _get_tid():
@@ -1000,58 +1053,54 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
         tids.append(script_ops.py_func(_get_tid, [], dtypes.int64))
       return tids
 
-    dataset = make_dataset_fn(dataset, _map_fn)
+    dataset = apply_map(dataset, _map_fn)
+    dataset._variant_tensor.op._set_attr("use_inter_op_parallelism",
+                                         attr_value_pb2.AttrValue(b=False))
     get_next = self.getNext(dataset)
 
     tids = self.evaluate(get_next())
     self.assertTrue(all(tids[0] == tid for tid in tids))
-# pylint: enable=g-long-lambda
 
-  @parameterized.named_parameters(
-      ("SequentialIdentity", None, lambda x: x, None),
-      ("SequentialReplicate", None, lambda x: (x, x), None),
-      ("SequentialSwap", (None, None), lambda x, y: (y, x), None),
-      ("SequentialProject", (None, None), lambda x, y: x, None),
-      ("ParallelIdentity", None, lambda x: x, 10),
-      ("ParallelReplicate", None, lambda x: (x, x), 10),
-      ("ParallelSwap", (None, None), lambda x, y: (y, x), 10),
-      ("ParallelProject", (None, None), lambda x, y: x, 10),
-  )
-  def testShortCircuit(self, structure, map_fn, num_parallel_calls):
-    dataset = self.structuredDataset(structure).repeat().map(
-        map_fn, num_parallel_calls=num_parallel_calls)
+  @combinations.generate(
+      combinations.times(_test_combinations(), _short_circuit_test_cases(),
+                         combinations.combine(num_parallel_calls=[None, 12])))
+  def testShortCircuit(self, apply_map, structure, fn, num_parallel_calls):
+    dataset = self.structuredDataset(structure).repeat()
+    dataset = apply_map(dataset, fn, num_parallel_calls=num_parallel_calls)
     get_next = self.getNext(dataset)
 
     if isinstance(structure, tuple):
-      expected = map_fn(*self.evaluate(self.structuredElement(structure)))
+      expected = fn(*self.evaluate(self.structuredElement(structure)))
     else:
-      expected = map_fn(self.evaluate(self.structuredElement(structure)))
+      expected = fn(self.evaluate(self.structuredElement(structure)))
     self.assertEqual(expected, self.evaluate(get_next()))
 
-  @parameterized.named_parameters(
-      ("Sequential", None),
-      ("Parallel", 10),
-  )
-  def testShortCircuitCapturedInput(self, num_parallel_calls):
+  @combinations.generate(
+      combinations.times(_test_combinations(),
+                         combinations.combine(num_parallel_calls=[None, 12])))
+  def testShortCircuitCapturedInput(self, apply_map, num_parallel_calls):
     captured_t = variables.Variable(42)
-    dataset = self.structuredDataset(None).repeat().map(
-        lambda x: captured_t, num_parallel_calls=num_parallel_calls)
+    dataset = self.structuredDataset(None).repeat()
+    dataset = apply_map(
+        dataset, lambda x: captured_t, num_parallel_calls=num_parallel_calls)
     self.evaluate(variables.global_variables_initializer())
     get_next = self.getNext(dataset, requires_initialization=True)
 
     self.assertEqual(42, self.evaluate(get_next()))
 
-  @parameterized.named_parameters(
-      ("1", 1, 1),
-      ("2", 10, 1),
-      ("3", 10, 10),
-      ("4", 100, 1),
-      ("5", 100, 10),
-      ("6", 100, 100),
-  )
-  def testSloppyInterleaveInOrder(self, num_elements, num_parallel_calls):
+  @combinations.generate(
+      combinations.times(
+          _test_combinations(),
+          combinations.combine(num_elements=1, num_parallel_calls=1) +
+          combinations.combine(num_elements=10, num_parallel_calls=1) +
+          combinations.combine(num_elements=10, num_parallel_calls=10) +
+          combinations.combine(num_elements=100, num_parallel_calls=1) +
+          combinations.combine(num_elements=100, num_parallel_calls=10) +
+          combinations.combine(num_elements=100, num_parallel_calls=100)))
+  def testSloppyInterleaveInOrder(self, apply_map, num_elements,
+                                  num_parallel_calls):
     dataset, coordination_events = _make_coordinated_sloppy_dataset(
-        num_elements, num_parallel_calls)
+        apply_map, num_elements, num_parallel_calls)
     options = dataset_ops.Options()
     options.experimental_threading = threading_options.ThreadingOptions()
     options.experimental_threading.private_threadpool_size = (
@@ -1064,14 +1113,16 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(get_next())
 
-  @parameterized.named_parameters(
-      ("1", 10, 10),
-      ("2", 100, 10),
-      ("3", 100, 100),
-  )
-  def testSloppyInterleaveOutOfOrder(self, num_elements, num_parallel_calls):
+  @combinations.generate(
+      combinations.times(
+          _test_combinations(),
+          combinations.combine(num_elements=10, num_parallel_calls=10) +
+          combinations.combine(num_elements=100, num_parallel_calls=10) +
+          combinations.combine(num_elements=100, num_parallel_calls=100)))
+  def testSloppyInterleaveOutOfOrder(self, apply_map, num_elements,
+                                     num_parallel_calls):
     dataset, coordination_events = _make_coordinated_sloppy_dataset(
-        num_elements, num_parallel_calls)
+        apply_map, num_elements, num_parallel_calls)
     options = dataset_ops.Options()
     options.experimental_threading = threading_options.ThreadingOptions()
     options.experimental_threading.private_threadpool_size = (
@@ -1090,25 +1141,25 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(get_next())
 
-  @parameterized.named_parameters(
-      ("Map", None),
-      ("ParallelMap", 12),
-  )
+  @combinations.generate(
+      combinations.combine(
+          tf_api_version=2,
+          mode=["eager", "graph"],
+          num_parallel_calls=[None, 12]))
   def testPreserveCardinality(self, num_parallel_calls):
 
     def py_fn(_):
       raise StopIteration()
 
-    dataset = dataset_ops.DatasetV2.from_tensors(0).map(
+    dataset = dataset_ops.Dataset.from_tensors(0).map(
         lambda x: script_ops.py_func(py_fn, [x], dtypes.int64),
         num_parallel_calls=num_parallel_calls)
     get_next = self.getNext(dataset)
     with self.assertRaises(errors.InvalidArgumentError):
       self.evaluate(get_next())
 
-  # NOTE: collection test is specific to graph mode only, no eager coverage.
-  @test_util.run_v1_only("graph specific test")
-  def testSkipEagerCollectionCopy(self):
+  @combinations.generate(_test_combinations_with_mode("graph"))
+  def testCollectionCopy(self, apply_map):
     w = variable_scope.get_variable("w", [])
     self.assertIn(w, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
 
@@ -1117,22 +1168,21 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
       return x
 
     dataset = dataset_ops.Dataset.from_tensors(constant_op.constant(1.0))
-    dataset.map(func)
+    _ = apply_map(dataset, func)
 
-  @parameterized.named_parameters(
-      ("Sequential", None),
-      ("Parallel", 12),
-  )
-  @test_util.run_v1_only("graph-mode specific test")
-  def testSkipEagerMapCancellation(self, num_parallel_calls):
+  @combinations.generate(
+      combinations.times(
+          _test_combinations_with_mode_v1("graph"),
+          combinations.combine(num_parallel_calls=[None, 12])))
+  def testMapCancellation(self, apply_map, num_parallel_calls):
     # Checks that a cancellation of is threaded through to map transformation.
     queue = data_flow_ops.FIFOQueue(10, dtypes.int32, ())
 
     def fn(_):
       return queue.dequeue()
 
-    dataset = dataset_ops.Dataset.range(1).map(
-        fn, num_parallel_calls=num_parallel_calls)
+    dataset = dataset_ops.Dataset.range(1)
+    dataset = apply_map(dataset, fn, num_parallel_calls=num_parallel_calls)
     get_next = self.getNext(dataset, requires_initialization=True)
 
     with self.cached_session() as sess:
@@ -1143,17 +1193,11 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
       thread.join()
 
 
-# TODO(shivaniagarwal): separate out `map` and `map_with_legacy_function` tests
-# as later would not work in v2.
-@test_util.run_all_in_graph_and_eager_modes
-class MapWithCapturedVariableTests(test_base.DatasetTestBase,
-                                   parameterized.TestCase):
-
   # TODO(b/126553094): map doesnt work with variable defined inside function in
   # eager mode, possible Graph tensors leak out of the function building context
   # from function graph in eager mode as variables are created in init_scope.
-  @test_util.run_v1_only("b/126553094")
-  def testSkipEagerCreateVariableInsideFunctionWithGetter(self):
+  @combinations.generate(test_base.graph_only_combinations())
+  def testCreateVariableInsideFunctionWithGetter(self):
 
     def func(_):
       with variable_scope.variable_scope(
@@ -1162,12 +1206,13 @@ class MapWithCapturedVariableTests(test_base.DatasetTestBase,
             "counter", (), dtypes.int32, use_resource=True)
       return counter_var.assign_add(1)
 
-    # NOTE: In the legacy function, resource is captured by value for variable
-    # getter.
     dataset = dataset_ops.Dataset.from_tensors(0).repeat(10)
-    with self.assertRaisesWithPredicateMatch(
-        AttributeError, "'Tensor' object has no attribute 'assign_add'"):
-      dataset.map_with_legacy_function(func)
+
+    if hasattr(dataset, "map_with_legacy_function"):
+      # NOTE: In the legacy function, resource is captured by value.
+      with self.assertRaisesWithPredicateMatch(
+          AttributeError, "'Tensor' object has no attribute 'assign_add'"):
+        dataset.map_with_legacy_function(func)
 
     dataset = dataset.map(func)
     self.evaluate(variables.global_variables_initializer())
@@ -1179,18 +1224,12 @@ class MapWithCapturedVariableTests(test_base.DatasetTestBase,
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(get_next())
 
-  @parameterized.named_parameters(
-      ("MapLegacyFunction",
-       lambda dataset, func: dataset.map_with_legacy_function(func)),
-      ("Map", lambda dataset, func: dataset.map(func)),
-  )
-  @test_util.run_v1_only("map_with_legacy_function is only available in v1.")
-  def testCaptureVariable(self, transformation_function):
+  @combinations.generate(_test_combinations())
+  def testCaptureVariable(self, apply_map):
     counter_var = variable_scope.get_variable(
         "counter", (), dtypes.int32, use_resource=True)
     dataset = dataset_ops.Dataset.from_tensors(0).repeat(10)
-    dataset = transformation_function(
-        dataset, lambda _: counter_var.assign_add(1))
+    dataset = apply_map(dataset, lambda _: counter_var.assign_add(1))
     get_next = self.getNext(dataset, requires_initialization=True)
 
     self.evaluate(counter_var.initializer)
@@ -1203,34 +1242,20 @@ class MapWithCapturedVariableTests(test_base.DatasetTestBase,
       self.evaluate(get_next())
     self.assertEqual(10, self.evaluate(counter_var))
 
-  # NOTE: no need to explicitly initialize variables in eager mode.
-  @parameterized.named_parameters(
-      ("MapLegacyFunction",
-       lambda dataset, func: dataset.map_with_legacy_function(func)),
-      ("Map", lambda dataset, func: dataset.map(func)),
-  )
-  @test_util.run_v1_only("this test is meant to run in graph mode only.")
-  def testSkipEagerCaptureUninitializedVariableError(self,
-                                                     transformation_function):
+  @combinations.generate(_test_combinations_with_mode_v1("graph"))
+  def testCaptureUninitializedVariableError(self, apply_map):
     counter_var = variable_scope.get_variable(
         "counter", (), dtypes.int32, use_resource=True)
     dataset = dataset_ops.Dataset.from_tensors(0).repeat(10)
-    dataset = transformation_function(
-        dataset, lambda _: counter_var.assign_add(1))
+    dataset = apply_map(dataset, lambda _: counter_var.assign_add(1))
 
     get_next = self.getNext(dataset, requires_initialization=True)
     with self.assertRaises(errors.NotFoundError):
       self.evaluate(get_next())
 
   # TODO(b/121264236): add eager mode coverage when we have multi-device setup.
-  @parameterized.named_parameters(
-      ("MapLegacyFunction",
-       lambda dataset, func: dataset.map_with_legacy_function(func)),
-      ("Map", lambda dataset, func: dataset.map(func)),
-  )
-  @test_util.run_v1_only("b/121264236")
-  def testSkipEagerCaptureConstantsWithConflictingDevices(
-      self, transformation_function):
+  @combinations.generate(_test_combinations_with_mode_v1("graph"))
+  def testCaptureConstantsWithConflictingDevices(self, apply_map):
     config = config_pb2.ConfigProto(device_count={"CPU": 3})
     with self.cached_session(config=config):
       with ops.device("/device:CPU:0"):
@@ -1242,13 +1267,13 @@ class MapWithCapturedVariableTests(test_base.DatasetTestBase,
         return math_ops.add(a, b)
 
       dataset = dataset_ops.Dataset.from_tensors(0).repeat(10)
-      dataset = transformation_function(dataset, func)
+      dataset = apply_map(dataset, func)
       expected_output = [8.0] * 10
       self.assertDatasetProduces(dataset, expected_output=expected_output)
 
   # TODO(b/121264236): add eager mode coverage when we have multi-device setup.
-  @test_util.run_v1_only("b/121264236")
-  def testSkipEagerRefVariablesWithMultipleDevices(self):
+  @combinations.generate(_test_combinations_with_mode_v1("graph"))
+  def testReferenceVariablesWithMultipleDevices(self, apply_map):
     config = config_pb2.ConfigProto(device_count={"CPU": 3})
     with self.cached_session(config=config):
 
@@ -1262,7 +1287,7 @@ class MapWithCapturedVariableTests(test_base.DatasetTestBase,
       # NOTE: Use the legacy function implementation as eager function will
       # convert RefVariables to ResourceVariables.
       dataset = dataset_ops.Dataset.from_tensors(0).repeat(10)
-      dataset = dataset.map_with_legacy_function(func)
+      dataset = apply_map(dataset, func)
       self.evaluate(variables.global_variables_initializer())
       expected_output = [8.0] * 10
       self.assertDatasetProduces(
@@ -1271,8 +1296,8 @@ class MapWithCapturedVariableTests(test_base.DatasetTestBase,
           requires_initialization=True)
 
   # TODO(b/121264236): add eager mode coverage when we have multi-device setup.
-  @test_util.run_v1_only("b/121264236")
-  def testSkipEagerResourceVariablesWithMultipleDevices(self):
+  @combinations.generate(_test_combinations_with_mode_v1("graph"))
+  def testResourceVariablesWithMultipleDevices(self, apply_map):
     config = config_pb2.ConfigProto(device_count={"CPU": 3})
 
     def func(_):
@@ -1287,25 +1312,10 @@ class MapWithCapturedVariableTests(test_base.DatasetTestBase,
               "b", (), dtypes.int32, use_resource=True)
       return math_ops.add(a_var, b_var)
 
-    g_1 = ops.Graph()
-    with self.session(config=config, graph=g_1):
-      # The MapDataset node ends up with two ResourceVariable inputs, one on
-      # device CPU:0 and the other on device CPU:1.
+    g = ops.Graph()
+    with self.session(config=config, graph=g):
       dataset = dataset_ops.Dataset.from_tensors(0).repeat(10)
-      dataset = dataset.map(func)
-      self.evaluate(variables.global_variables_initializer())
-      expected_output = [1] * 10
-      self.assertDatasetProduces(
-          dataset,
-          expected_output=expected_output,
-          requires_initialization=True)
-
-    g_2 = ops.Graph()
-    with self.session(config=config, graph=g_2):
-      # In old-Defun variable is captured as value, hence there is no colocation
-      # error.
-      dataset = dataset_ops.Dataset.from_tensors(0).repeat(10)
-      dataset = dataset.map_with_legacy_function(func)
+      dataset = apply_map(dataset, func)
       self.evaluate(variables.global_variables_initializer())
       expected_output = [1] * 10
       self.assertDatasetProduces(
diff --git a/tensorflow/python/data/kernel_tests/memory_cleanup_test.py b/tensorflow/python/data/kernel_tests/memory_cleanup_test.py
index a2015ef47d1..5b0ea02a054 100644
--- a/tensorflow/python/data/kernel_tests/memory_cleanup_test.py
+++ b/tensorflow/python/data/kernel_tests/memory_cleanup_test.py
@@ -119,8 +119,7 @@ class MemoryCleanupTest(test_base.DatasetTestBase, parameterized.TestCase):
     ]
     self.assertEmpty(tensors, "%d Tensors are still alive." % len(tensors))
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode="eager"))
+  @combinations.generate(test_base.eager_only_combinations())
   def testFilter(self):
 
     def get_dataset():
@@ -144,8 +143,7 @@ class MemoryCleanupTest(test_base.DatasetTestBase, parameterized.TestCase):
 
     self._testIteratorMemoryLeak(get_dataset)
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode="eager"))
+  @combinations.generate(test_base.eager_only_combinations())
   def testFlatMap(self):
 
     def get_dataset():
@@ -157,8 +155,7 @@ class MemoryCleanupTest(test_base.DatasetTestBase, parameterized.TestCase):
 
     self._testIteratorMemoryLeak(get_dataset)
 
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode="eager"))
+  @combinations.generate(test_base.eager_only_combinations())
   def testFromGenerator(self):
 
     def get_dataset():
@@ -171,8 +168,8 @@ class MemoryCleanupTest(test_base.DatasetTestBase, parameterized.TestCase):
     self._testIteratorMemoryLeak(get_dataset)
 
   @combinations.generate(
-      combinations.combine(
-          tf_api_version=[1, 2], mode="eager", num_parallel_calls=[None, 10]))
+      combinations.times(test_base.eager_only_combinations(),
+                         combinations.combine(num_parallel_calls=[None, 10])))
   def testMap(self, num_parallel_calls):
 
     def get_dataset():
@@ -201,8 +198,8 @@ class MemoryCleanupTest(test_base.DatasetTestBase, parameterized.TestCase):
     self._testIteratorMemoryLeak(get_dataset)
 
   @combinations.generate(
-      combinations.combine(
-          tf_api_version=[1, 2], mode="eager", num_parallel_calls=[None, 10]))
+      combinations.times(test_base.eager_only_combinations(),
+                         combinations.combine(num_parallel_calls=[None, 10])))
   def testInterleave(self, num_parallel_calls):
 
     def get_dataset():
diff --git a/tensorflow/python/data/kernel_tests/optional_test.py b/tensorflow/python/data/kernel_tests/optional_test.py
index 3ab6717b9c3..f0795563d09 100644
--- a/tensorflow/python/data/kernel_tests/optional_test.py
+++ b/tensorflow/python/data/kernel_tests/optional_test.py
@@ -17,6 +17,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import functools
+
 from absl.testing import parameterized
 import numpy as np
 
@@ -27,6 +29,7 @@ from tensorflow.python.data.ops import optional_ops
 from tensorflow.python.data.util import structure
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
+from tensorflow.python.framework import combinations
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -40,14 +43,90 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 
 
-@test_util.run_all_in_graph_and_eager_modes
+def _optional_spec_test_combinations():
+  # pylint: disable=g-long-lambda
+  cases = [
+      ("Dense", lambda: constant_op.constant(37.0),
+       tensor_spec.TensorSpec([], dtypes.float32)),
+      ("Sparse", lambda: sparse_tensor.SparseTensor(
+          indices=[[0, 1]],
+          values=constant_op.constant([0], dtype=dtypes.int32),
+          dense_shape=[10, 10]),
+       sparse_tensor.SparseTensorSpec([10, 10], dtypes.int32)),
+      ("Nest", lambda: {
+          "a": constant_op.constant(37.0),
+          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
+      }, {
+          "a":
+              tensor_spec.TensorSpec([], dtypes.float32),
+          "b": (
+              tensor_spec.TensorSpec([1], dtypes.string),
+              tensor_spec.TensorSpec([], dtypes.string),
+          )
+      }),
+      ("Optional", lambda: optional_ops.Optional.from_value(37.0),
+       optional_ops.OptionalSpec(tensor_spec.TensorSpec([], dtypes.float32))),
+  ]
+
+  def reduce_fn(x, y):
+    name, value_fn, expected_structure = y
+    return x + combinations.combine(
+        tf_value_fn=combinations.NamedObject(name, value_fn),
+        expected_value_structure=expected_structure)
+
+  return functools.reduce(reduce_fn, cases, [])
+
+
+def _get_next_as_optional_test_combinations():
+  # pylint: disable=g-long-lambda
+  cases = [
+      ("Dense", np.array([1, 2, 3], dtype=np.int32),
+       lambda: constant_op.constant([4, 5, 6], dtype=dtypes.int32), True),
+      ("Sparse",
+       sparse_tensor.SparseTensorValue(
+           indices=[[0, 0], [1, 1]],
+           values=np.array([-1., 1.], dtype=np.float32),
+           dense_shape=[2, 2]),
+       lambda: sparse_tensor.SparseTensor(
+           indices=[[0, 1], [1, 0]], values=[37.0, 42.0], dense_shape=[2, 2]),
+       False),
+      ("Nest", {
+          "a":
+              np.array([1, 2, 3], dtype=np.int32),
+          "b":
+              sparse_tensor.SparseTensorValue(
+                  indices=[[0, 0], [1, 1]],
+                  values=np.array([-1., 1.], dtype=np.float32),
+                  dense_shape=[2, 2])
+      }, lambda: {
+          "a":
+              constant_op.constant([4, 5, 6], dtype=dtypes.int32),
+          "b":
+              sparse_tensor.SparseTensor(
+                  indices=[[0, 1], [1, 0]],
+                  values=[37.0, 42.0],
+                  dense_shape=[2, 2])
+      }, False),
+  ]
+
+  def reduce_fn(x, y):
+    name, value, value_fn, gpu_compatible = y
+    return x + combinations.combine(
+        np_value=value, tf_value_fn=combinations.NamedObject(name, value_fn),
+        gpu_compatible=gpu_compatible)
+
+  return functools.reduce(reduce_fn, cases, [])
+
+
 class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
 
+  @combinations.generate(test_base.default_test_combinations())
   def testFromValue(self):
     opt = optional_ops.Optional.from_value(constant_op.constant(37.0))
     self.assertTrue(self.evaluate(opt.has_value()))
     self.assertEqual(37.0, self.evaluate(opt.get_value()))
 
+  @combinations.generate(test_base.default_test_combinations())
   def testFromStructuredValue(self):
     opt = optional_ops.Optional.from_value({
         "a": constant_op.constant(37.0),
@@ -59,6 +138,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
         "b": ([b"Foo"], b"Bar")
     }, self.evaluate(opt.get_value()))
 
+  @combinations.generate(test_base.default_test_combinations())
   def testFromSparseTensor(self):
     st_0 = sparse_tensor.SparseTensorValue(
         indices=np.array([[0]]),
@@ -77,6 +157,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
       self.assertAllEqual(expected.dense_shape,
                           self.evaluate(actual.dense_shape))
 
+  @combinations.generate(test_base.default_test_combinations())
   def testFromNone(self):
     value_structure = tensor_spec.TensorSpec([], dtypes.float32)
     opt = optional_ops.Optional.none_from_structure(value_structure)
@@ -91,6 +172,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
     with self.assertRaises(errors.InvalidArgumentError):
       self.evaluate(opt.get_value())
 
+  @combinations.generate(test_base.default_test_combinations())
   def testAddN(self):
     devices = ["/cpu:0"]
     if test_util.is_gpu_available():
@@ -117,6 +199,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
                                              opt_none1.value_structure)
         self.assertFalse(self.evaluate(add_opt.has_value()))
 
+  @combinations.generate(test_base.default_test_combinations())
   def testNestedAddN(self):
     devices = ["/cpu:0"]
     if test_util.is_gpu_available():
@@ -137,6 +220,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
                                                    opt1.value_structure)
         self.assertAllEqual(inner_add_opt.get_value(), [4, 6.0])
 
+  @combinations.generate(test_base.default_test_combinations())
   def testZerosLike(self):
     devices = ["/cpu:0"]
     if test_util.is_gpu_available():
@@ -159,6 +243,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
                                                opt_none.value_structure)
         self.assertFalse(self.evaluate(zeros_opt.has_value()))
 
+  @combinations.generate(test_base.default_test_combinations())
   def testNestedZerosLike(self):
     devices = ["/cpu:0"]
     if test_util.is_gpu_available():
@@ -175,6 +260,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
                                                      opt1.value_structure)
         self.assertEqual(self.evaluate(inner_zeros_opt.get_value()), 0.0)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testCopyToGPU(self):
     if not test_util.is_gpu_available():
       self.skipTest("No GPU available")
@@ -204,6 +290,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
                      self.evaluate(gpu_optional_with_value_values))
     self.assertFalse(self.evaluate(gpu_optional_none_has_value))
 
+  @combinations.generate(test_base.default_test_combinations())
   def testNestedCopyToGPU(self):
     if not test_util.is_gpu_available():
       self.skipTest("No GPU available")
@@ -239,42 +326,10 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertFalse(self.evaluate(inner_none.has_value()))
     self.assertEqual(1.0, self.evaluate(gpu_nested_optional_values[2]))
 
-  def _assertElementValueEqual(self, expected, actual):
-    if isinstance(expected, dict):
-      self.assertItemsEqual(list(expected.keys()), list(actual.keys()))
-      for k in expected.keys():
-        self._assertElementValueEqual(expected[k], actual[k])
-    elif isinstance(expected, sparse_tensor.SparseTensorValue):
-      self.assertAllEqual(expected.indices, actual.indices)
-      self.assertAllEqual(expected.values, actual.values)
-      self.assertAllEqual(expected.dense_shape, actual.dense_shape)
-    else:
-      self.assertAllEqual(expected, actual)
-
-  # pylint: disable=g-long-lambda
-  @parameterized.named_parameters(
-      ("Tensor", lambda: constant_op.constant(37.0),
-       tensor_spec.TensorSpec([], dtypes.float32)),
-      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
-          indices=[[0, 1]],
-          values=constant_op.constant([0], dtype=dtypes.int32),
-          dense_shape=[10, 10]),
-       sparse_tensor.SparseTensorSpec([10, 10], dtypes.int32)),
-      ("Nest", lambda: {
-          "a": constant_op.constant(37.0),
-          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
-      }, {
-          "a":
-              tensor_spec.TensorSpec([], dtypes.float32),
-          "b": (
-              tensor_spec.TensorSpec([1], dtypes.string),
-              tensor_spec.TensorSpec([], dtypes.string),
-          )
-      }),
-      ("Optional", lambda: optional_ops.Optional.from_value(37.0),
-       optional_ops.OptionalSpec(
-           tensor_spec.TensorSpec([], dtypes.float32))),
-  )
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          _optional_spec_test_combinations()))
   def testOptionalSpec(self, tf_value_fn, expected_value_structure):
     tf_value = tf_value_fn()
     opt = optional_ops.Optional.from_value(tf_value)
@@ -304,36 +359,21 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
     round_trip_opt = opt_structure._from_tensor_list(
         opt_structure._to_tensor_list(opt))
     if isinstance(tf_value, optional_ops.Optional):
-      self._assertElementValueEqual(
+      self.assertValuesEqual(
           self.evaluate(tf_value.get_value()),
           self.evaluate(round_trip_opt.get_value().get_value()))
     else:
-      self._assertElementValueEqual(
+      self.assertValuesEqual(
           self.evaluate(tf_value),
           self.evaluate(round_trip_opt.get_value()))
 
-  @parameterized.named_parameters(
-      ("Tensor", np.array([1, 2, 3], dtype=np.int32),
-       lambda: constant_op.constant([4, 5, 6], dtype=dtypes.int32), True),
-      ("SparseTensor", sparse_tensor.SparseTensorValue(
-          indices=[[0, 0], [1, 1]],
-          values=np.array([-1., 1.], dtype=np.float32), dense_shape=[2, 2]),
-       lambda: sparse_tensor.SparseTensor(
-           indices=[[0, 1], [1, 0]], values=[37.0, 42.0], dense_shape=[2, 2]),
-       False),
-      ("Nest", {"a": np.array([1, 2, 3], dtype=np.int32),
-                "b": sparse_tensor.SparseTensorValue(
-                    indices=[[0, 0], [1, 1]],
-                    values=np.array([-1., 1.], dtype=np.float32),
-                    dense_shape=[2, 2])},
-       lambda: {"a": constant_op.constant([4, 5, 6], dtype=dtypes.int32),
-                "b": sparse_tensor.SparseTensor(
-                    indices=[[0, 1], [1, 0]], values=[37.0, 42.0],
-                    dense_shape=[2, 2])}, False),
-  )
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          _get_next_as_optional_test_combinations()))
   def testIteratorGetNextAsOptional(self, np_value, tf_value_fn,
-                                    works_on_gpu):
-    if not works_on_gpu and test.is_gpu_available():
+                                    gpu_compatible):
+    if not gpu_compatible and test.is_gpu_available():
       self.skipTest("Test case not yet supported on GPU.")
     ds = dataset_ops.Dataset.from_tensors(np_value).repeat(3)
 
@@ -348,7 +388,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
             next_elem.value_structure,
             structure.type_spec_from_value(tf_value_fn())))
         self.assertTrue(next_elem.has_value())
-        self._assertElementValueEqual(np_value, next_elem.get_value())
+        self.assertValuesEqual(np_value, next_elem.get_value())
       # After exhausting the iterator, `next_elem.has_value()` will evaluate to
       # false, and attempting to get the value will fail.
       for _ in range(2):
@@ -379,7 +419,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
         elem_has_value, elem_value = self.evaluate(
             [elem_has_value_t, elem_value_t])
         self.assertTrue(elem_has_value)
-        self._assertElementValueEqual(np_value, elem_value)
+        self.assertValuesEqual(np_value, elem_value)
 
       # After exhausting the iterator, `next_elem.has_value()` will evaluate to
       # false, and attempting to get the value will fail.
@@ -388,6 +428,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
         with self.assertRaises(errors.InvalidArgumentError):
           self.evaluate(elem_value_t)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testFunctionBoundaries(self):
     @def_function.function
     def get_optional():
@@ -407,6 +448,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
     val = consume_optional(opt_tensor)
     self.assertEqual(self.evaluate(val), 1.0)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testLimitedRetracing(self):
     trace_count = [0]
 
diff --git a/tensorflow/python/data/kernel_tests/options_test.py b/tensorflow/python/data/kernel_tests/options_test.py
index 222d8c6f1a6..b38d008b833 100644
--- a/tensorflow/python/data/kernel_tests/options_test.py
+++ b/tensorflow/python/data/kernel_tests/options_test.py
@@ -18,25 +18,31 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
+
 from tensorflow.python.data.experimental.ops import optimization_options
 from tensorflow.python.data.experimental.ops import stats_options
 from tensorflow.python.data.experimental.ops import threading_options
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import combinations
 from tensorflow.python.platform import test
 
 
-class OptionsTest(test_base.DatasetTestBase):
+class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
 
+  @combinations.generate(test_base.default_test_combinations())
   def testOptionsDefault(self):
     ds = dataset_ops.Dataset.range(0)
     self.assertEqual(dataset_ops.Options(), ds.options())
 
+  @combinations.generate(test_base.default_test_combinations())
   def testOptionsOnce(self):
     options = dataset_ops.Options()
     ds = dataset_ops.Dataset.range(0).with_options(options).cache()
     self.assertEqual(options, ds.options())
 
+  @combinations.generate(test_base.default_test_combinations())
   def testOptionsTwiceSame(self):
     options = dataset_ops.Options()
     options.experimental_optimization.autotune = True
@@ -44,6 +50,7 @@ class OptionsTest(test_base.DatasetTestBase):
         options)
     self.assertEqual(options, ds.options())
 
+  @combinations.generate(test_base.default_test_combinations())
   def testOptionsTwiceDifferent(self):
     options1 = dataset_ops.Options()
     options1.experimental_optimization.autotune = True
@@ -55,6 +62,7 @@ class OptionsTest(test_base.DatasetTestBase):
     # Explicitly check that flag is False since assertFalse allows None
     self.assertIs(ds.options().experimental_deterministic, False)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testOptionsTwiceDifferentError(self):
     options1 = dataset_ops.Options()
     options1.experimental_optimization.autotune = True
@@ -64,6 +72,7 @@ class OptionsTest(test_base.DatasetTestBase):
                                  "Cannot merge incompatible values"):
       dataset_ops.Dataset.range(0).with_options(options1).with_options(options2)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testOptionsMergeOptionsFromMultipleInputs(self):
     options1 = dataset_ops.Options()
     options1.experimental_optimization.autotune = True
@@ -75,6 +84,7 @@ class OptionsTest(test_base.DatasetTestBase):
     self.assertTrue(ds.options().experimental_optimization.autotune)
     self.assertTrue(ds.options().experimental_deterministic)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testOptionsHaveDefaults(self):
     options1 = dataset_ops.Options()
     options2 = dataset_ops.Options()
@@ -84,12 +94,11 @@ class OptionsTest(test_base.DatasetTestBase):
                      options2.experimental_stats)
     self.assertIsNot(options1.experimental_threading,
                      options2.experimental_threading)
-    self.assertEquals(options1.experimental_optimization,
-                      optimization_options.OptimizationOptions())
-    self.assertEquals(options1.experimental_stats,
-                      stats_options.StatsOptions())
-    self.assertEquals(options1.experimental_threading,
-                      threading_options.ThreadingOptions())
+    self.assertEqual(options1.experimental_optimization,
+                     optimization_options.OptimizationOptions())
+    self.assertEqual(options1.experimental_stats, stats_options.StatsOptions())
+    self.assertEqual(options1.experimental_threading,
+                     threading_options.ThreadingOptions())
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/kernel_tests/padded_batch_test.py b/tensorflow/python/data/kernel_tests/padded_batch_test.py
index 39339c0063a..a3b8f3945f3 100644
--- a/tensorflow/python/data/kernel_tests/padded_batch_test.py
+++ b/tensorflow/python/data/kernel_tests/padded_batch_test.py
@@ -23,43 +23,30 @@ import numpy as np
 
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import combinations
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import string_ops
 from tensorflow.python.platform import test
 from tensorflow.python.util import compat
 
 
-def _random_seq_lens(count):
-  return np.random.randint(20, size=(count,)).astype(np.int32)
-
-
-@test_util.run_all_in_graph_and_eager_modes
 class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
 
-  @parameterized.named_parameters(
-      ('default_padding', _random_seq_lens(32), 4, [-1], False),
-      ('constant_padding', _random_seq_lens(32), 4, [25], False),
-      ('uneven_with_remainder', _random_seq_lens(34), 4, [-1], False),
-      ('uneven_without_remainder', _random_seq_lens(34), 4, [-1], True),
-  )
-  def testPaddedBatchDataset(self, seq_lens, batch_size, padded_shapes,
-                             drop_remainder):
-    """Tests the padded batch dataset logic for various input configurations.
-
-    Args:
-      seq_lens: the input sequence lengths
-      batch_size: the batch size
-      padded_shapes: the padded shapes to use
-      drop_remainder: whether a smaller batch size should be produced if batch
-        size does not divide number of inputs evenly
-    """
-
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(
+              count=[32, 34],
+              padded_shapes=[[None], [25]],
+              drop_remainder=[True, False])))
+  def testPaddedBatchDataset(self, count, padded_shapes, drop_remainder):
+    seq_lens = np.random.randint(20, size=(count,)).astype(np.int32)
+    batch_size = 4
     dataset = dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
         lambda x: array_ops.fill([x], x)).padded_batch(
             batch_size=batch_size,
@@ -81,7 +68,9 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
 
     if not drop_remainder and len(seq_lens) % batch_size > 0:
       result = self.evaluate(get_next())
-      padded_len = np.max(result) if result.size > 0 else 0
+      padded_len = padded_shapes[0]
+      if padded_len is None or padded_len == -1:
+        padded_len = np.max(result) if result.size > 0 else 0
       self.assertEqual((len(seq_lens) % batch_size, padded_len), result.shape)
       for j in range(len(seq_lens) % batch_size):
         seq_len = seq_lens[num_full_batches * batch_size + j]
@@ -93,7 +82,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(get_next())
 
-  @test_util.run_deprecated_v1
+  @combinations.generate(test_base.default_test_combinations())
   def testPaddedBatchShortPadding(self):
     dataset = (
         dataset_ops.Dataset.from_tensor_slices(
@@ -102,6 +91,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertDatasetProduces(
         dataset, expected_error=(errors.DataLossError, ''))
 
+  @combinations.generate(test_base.default_test_combinations())
   def testPaddedBatchEmptyTensors(self):
     dataset = (
         dataset_ops.Dataset.from_tensor_slices(
@@ -109,6 +99,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
                 batch_size=4, padded_shapes=[-1]))
     self.assertDatasetProduces(dataset, expected_output=[[[], [], [], []]])
 
+  @combinations.generate(test_base.default_test_combinations())
   def testPaddedBatchDatasetNonDefaultPadding(self):
 
     def fill_tuple(x):
@@ -139,6 +130,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(get_next())
 
+  @combinations.generate(test_base.default_test_combinations())
   def testPaddedBatchDatasetUnicode(self):
     # See GitHub issue 16149
     def generator():
@@ -156,9 +148,8 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     next_element = self.getNext(padded_dataset)
     self.evaluate(next_element())
 
-  # NOTE: This test is specific to graph mode and is skipped in eager mode.
-  @test_util.run_deprecated_v1
-  def testSkipEagerPaddedBatchDatasetShapeSpecifications(self):
+  @combinations.generate(test_base.graph_only_combinations())
+  def testPaddedBatchDatasetShapeSpecifications(self):
     int_placeholder = array_ops.placeholder(dtypes.int32)
     float_placeholder = array_ops.placeholder(dtypes.float32)
     string_placeholder = array_ops.placeholder(dtypes.string)
@@ -190,6 +181,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
       self.assertEqual([None, None, None], dataset_output_shapes[1].as_list())
       self.assertEqual([None, 37], dataset_output_shapes[2].as_list())
 
+  @combinations.generate(test_base.default_test_combinations())
   def testPaddedBatchSparseError(self):
 
     def _map_fn(i):
@@ -199,6 +191,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     with self.assertRaises(TypeError):
       _ = dataset_ops.Dataset.range(10).map(_map_fn).padded_batch(10)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testPaddedBatchShapeError(self):
     with self.assertRaisesRegexp(
         ValueError, r'The padded shape \(1,\) is not compatible with the '
@@ -230,9 +223,8 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
       _ = dataset_ops.Dataset.range(10).padded_batch(
           5, padded_shapes=shape_as_tensor)
 
-  # NOTE: This test is specific to graph mode and is skipped in eager mode.
-  @test_util.run_deprecated_v1
-  def testSkipEagerPaddedBatchShapeError(self):
+  @combinations.generate(test_base.graph_only_combinations())
+  def testPaddedBatchShapeErrorPlaceholder(self):
     with self.assertRaisesRegexp(
         ValueError,
         r'The padded shape \((\?|None), (\?|None)\) is not compatible with the '
diff --git a/tensorflow/python/data/kernel_tests/prefetch_test.py b/tensorflow/python/data/kernel_tests/prefetch_test.py
index 427fbf1d29f..c6d2877ee7c 100644
--- a/tensorflow/python/data/kernel_tests/prefetch_test.py
+++ b/tensorflow/python/data/kernel_tests/prefetch_test.py
@@ -23,36 +23,41 @@ from absl.testing import parameterized
 
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import combinations
 from tensorflow.python.framework import errors
-from tensorflow.python.framework import test_util
 from tensorflow.python.platform import test
 
 
-@test_util.run_all_in_graph_and_eager_modes
 class PrefetchTest(test_base.DatasetTestBase, parameterized.TestCase):
 
-  @parameterized.parameters((-1), (0), (5))
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         combinations.combine(buffer_size=[-1, None, 0, 42])))
   def testBufferSize(self, buffer_size):
     dataset = dataset_ops.Dataset.range(10).prefetch(buffer_size=buffer_size)
     self.assertDatasetProduces(dataset, expected_output=range(10))
 
-  @parameterized.parameters((-2), (-42))
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         combinations.combine(buffer_size=[-2, -42])))
   def testInvalidBufferSize(self, buffer_size):
     with self.assertRaises(errors.InvalidArgumentError):
       dataset = dataset_ops.Dataset.range(10).prefetch(buffer_size=buffer_size)
       self.evaluate(dataset._variant_tensor)
 
-  @parameterized.parameters(*[(buffer_size, slack_period)
-                              for buffer_size in (-1, None, 0, 5)
-                              for slack_period in (1, 8)])
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(
+              buffer_size=[-1, None, 0, 42], slack_period=[1, 8])))
   def testPrefetchWithSlack(self, buffer_size, slack_period):
     dataset = dataset_ops.Dataset.range(100)
     dataset = dataset_ops.PrefetchDataset(
         dataset, buffer_size, slack_period=slack_period)
     self.assertDatasetProduces(dataset, expected_output=range(100))
 
-  @test_util.run_v1_only("graph-mode specific test")
-  def testSkipEagerPrefetchCancellation(self):
+  @combinations.generate(combinations.combine(tf_api_version=1, mode="graph"))
+  def testPrefetchCancellation(self):
 
     def map_py_fn(x):
       while x > -1:
diff --git a/tensorflow/python/data/kernel_tests/range_test.py b/tensorflow/python/data/kernel_tests/range_test.py
index b7ac60c3fff..d136565ce42 100644
--- a/tensorflow/python/data/kernel_tests/range_test.py
+++ b/tensorflow/python/data/kernel_tests/range_test.py
@@ -17,51 +17,60 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
+
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import combinations
 from tensorflow.python.framework import errors
-from tensorflow.python.framework import test_util
 from tensorflow.python.platform import test
 
 
-@test_util.run_all_in_graph_and_eager_modes
-class RangeTest(test_base.DatasetTestBase):
+class RangeTest(test_base.DatasetTestBase, parameterized.TestCase):
 
+  @combinations.generate(test_base.default_test_combinations())
   def testStop(self):
     dataset = dataset_ops.Dataset.range(5)
     self.assertDatasetProduces(dataset, expected_output=range(5))
 
+  @combinations.generate(test_base.default_test_combinations())
   def testStartStop(self):
     start, stop = 2, 5
     dataset = dataset_ops.Dataset.range(start, stop)
     self.assertDatasetProduces(dataset, expected_output=range(2, 5))
 
+  @combinations.generate(test_base.default_test_combinations())
   def testStartStopStep(self):
     start, stop, step = 2, 10, 2
     dataset = dataset_ops.Dataset.range(start, stop, step)
     self.assertDatasetProduces(dataset, expected_output=range(2, 10, 2))
 
+  @combinations.generate(test_base.default_test_combinations())
   def testZeroStep(self):
     start, stop, step = 2, 10, 0
     with self.assertRaises(errors.InvalidArgumentError):
       dataset = dataset_ops.Dataset.range(start, stop, step)
       self.evaluate(dataset._variant_tensor)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testNegativeStep(self):
     start, stop, step = 2, 10, -1
     dataset = dataset_ops.Dataset.range(start, stop, step)
     self.assertDatasetProduces(dataset, expected_output=range(2, 10, -1))
 
+  @combinations.generate(test_base.default_test_combinations())
   def testStopLessThanStart(self):
     start, stop = 10, 2
     dataset = dataset_ops.Dataset.range(start, stop)
     self.assertDatasetProduces(dataset, expected_output=range(10, 2))
 
+  @combinations.generate(test_base.default_test_combinations())
   def testStopLessThanStartWithPositiveStep(self):
     start, stop, step = 10, 2, 2
     dataset = dataset_ops.Dataset.range(start, stop, step)
     self.assertDatasetProduces(dataset, expected_output=range(10, 2, 2))
 
+  @combinations.generate(test_base.default_test_combinations())
   def testStopLessThanStartWithNegativeStep(self):
     start, stop, step = 10, 2, -1
     dataset = dataset_ops.Dataset.range(start, stop, step)
diff --git a/tensorflow/python/data/kernel_tests/repeat_test.py b/tensorflow/python/data/kernel_tests/repeat_test.py
index 8a8537b30cf..c4262fcc08c 100644
--- a/tensorflow/python/data/kernel_tests/repeat_test.py
+++ b/tensorflow/python/data/kernel_tests/repeat_test.py
@@ -17,43 +17,33 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import test_util
+from tensorflow.python.framework import combinations
 from tensorflow.python.platform import test
 
 
-@test_util.run_all_in_graph_and_eager_modes
-class RepeatTest(test_base.DatasetTestBase):
+class RepeatTest(test_base.DatasetTestBase, parameterized.TestCase):
 
-  def testRepeatTensorDataset(self):
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         combinations.combine(count=[0, 3, 7])))
+  def testFiniteRepeat(self, count):
     """Test a dataset that repeats its input multiple times."""
     components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
-    # This placeholder can be fed when dataset-definition subgraph
-    # runs (i.e. `init_op` below) to configure the number of
-    # repetitions used in a particular iterator.
+    dataset = dataset_ops.Dataset.from_tensors(components).repeat(count)
+    self.assertEqual(
+        [c.shape for c in components],
+        [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
+    self.assertDatasetProduces(dataset, [components] * count)
 
-    def do_test(count):
-      dataset = dataset_ops.Dataset.from_tensors(components).repeat(count)
-      self.assertEqual(
-          [c.shape for c in components],
-          [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
-      self.assertDatasetProduces(dataset, [components] * count)
-
-    # Test a finite repetition.
-    do_test(3)
-
-    # test a different finite repetition.
-    do_test(7)
-
-    # Test an empty repetition.
-    do_test(0)
-
-    # Test an infinite repetition.
-    # NOTE(mrry): There's not a good way to test that the sequence
-    # actually is infinite.
+  @combinations.generate(test_base.default_test_combinations())
+  def testInfiniteRepeat(self):
+    # NOTE(mrry): There's not a good way to test that the sequence is infinite.
+    components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
     dataset = dataset_ops.Dataset.from_tensors(components).repeat(-1)
     self.assertEqual(
         [c.shape for c in components],
@@ -64,7 +54,8 @@ class RepeatTest(test_base.DatasetTestBase):
       for component, result_component in zip(components, results):
         self.assertAllEqual(component, result_component)
 
-  def testRepeatRepeatTensorDataset(self):
+  @combinations.generate(test_base.default_test_combinations())
+  def testRepeatRepeat(self):
     """Test the composition of repeat datasets."""
     components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
     inner_count, outer_count = 7, 14
@@ -77,11 +68,6 @@ class RepeatTest(test_base.DatasetTestBase):
     self.assertDatasetProduces(dataset,
                                [components] * (inner_count * outer_count))
 
-  def testRepeatEmptyDataset(self):
-    """Test that repeating an empty dataset does not hang."""
-    dataset = dataset_ops.Dataset.from_tensors(0).repeat(10).skip(10).repeat(-1)
-    self.assertDatasetProduces(dataset, [])
-
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/data/kernel_tests/shard_test.py b/tensorflow/python/data/kernel_tests/shard_test.py
index 9fc70ff6075..5830b66d61c 100644
--- a/tensorflow/python/data/kernel_tests/shard_test.py
+++ b/tensorflow/python/data/kernel_tests/shard_test.py
@@ -17,66 +17,79 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
+
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import combinations
 from tensorflow.python.framework import errors
-from tensorflow.python.framework import test_util
 from tensorflow.python.platform import test
 
 
-@test_util.run_v1_only("deprecated API, no eager or V2 test coverage")
-class ShardTest(test_base.DatasetTestBase):
+class ShardTest(test_base.DatasetTestBase, parameterized.TestCase):
 
+  @combinations.generate(test_base.default_test_combinations())
   def testSimpleCase(self):
     dataset = dataset_ops.Dataset.range(10).shard(5, 2)
     self.assertDatasetProduces(dataset, expected_output=[2, 7])
 
+  @combinations.generate(test_base.default_test_combinations())
   def testNestedData(self):
     dataset_a = dataset_ops.Dataset.range(10)
     dataset_b = dataset_ops.Dataset.range(10, 0, -1)
     dataset = dataset_ops.Dataset.zip((dataset_a, dataset_b)).shard(5, 2)
     self.assertDatasetProduces(dataset, expected_output=[(2, 8), (7, 3)])
 
+  @combinations.generate(test_base.default_test_combinations())
   def testOffsetZero(self):
     dataset = dataset_ops.Dataset.range(10).shard(5, 0)
     self.assertDatasetProduces(dataset, expected_output=[0, 5])
 
+  @combinations.generate(test_base.default_test_combinations())
   def testOffsetGreaterNumShards(self):
     with self.assertRaises(errors.InvalidArgumentError):
       dataset = dataset_ops.Dataset.range(10).shard(5, 7)
       self.evaluate(self.getNext(dataset)())
 
+  @combinations.generate(test_base.default_test_combinations())
   def testNegativeOffset(self):
     with self.assertRaises(errors.InvalidArgumentError):
       dataset = dataset_ops.Dataset.range(10).shard(5, -3)
       self.evaluate(self.getNext(dataset)())
 
+  @combinations.generate(test_base.default_test_combinations())
   def testNegativeNumShards(self):
     with self.assertRaises(errors.InvalidArgumentError):
       dataset = dataset_ops.Dataset.range(10).shard(-3, 1)
       self.evaluate(self.getNext(dataset)())
 
+  @combinations.generate(test_base.default_test_combinations())
   def testZeroNumShards(self):
     with self.assertRaises(errors.InvalidArgumentError):
       dataset = dataset_ops.Dataset.range(10).shard(0, 1)
       self.evaluate(self.getNext(dataset)())
 
+  @combinations.generate(test_base.default_test_combinations())
   def testIteratorEndsBeforeFirstElem(self):
     dataset = dataset_ops.Dataset.range(1).shard(5, 2)
     self.assertDatasetProduces(dataset, expected_output=[])
 
+  @combinations.generate(test_base.default_test_combinations())
   def testLargerWorkerPool(self):
     dataset = dataset_ops.Dataset.range(10).shard(7, 5)
     self.assertDatasetProduces(dataset, expected_output=[5])
 
+  @combinations.generate(test_base.default_test_combinations())
   def testIndexEqualsNumShards(self):
     dataset = dataset_ops.Dataset.range(10).shard(5, 4)
     self.assertDatasetProduces(dataset, expected_output=[4, 9])
 
+  @combinations.generate(test_base.default_test_combinations())
   def testIndexEqualsNumShards2(self):
     dataset = dataset_ops.Dataset.range(10).shard(4, 3)
     self.assertDatasetProduces(dataset, expected_output=[3, 7])
 
+  @combinations.generate(test_base.default_test_combinations())
   def testNumShardsLargerThanDataset(self):
     dataset = dataset_ops.Dataset.range(10).shard(20, 5)
     self.assertDatasetProduces(dataset, expected_output=[5])
diff --git a/tensorflow/python/data/kernel_tests/shuffle_test.py b/tensorflow/python/data/kernel_tests/shuffle_test.py
index 7f801e1b5f4..c9d17b79016 100644
--- a/tensorflow/python/data/kernel_tests/shuffle_test.py
+++ b/tensorflow/python/data/kernel_tests/shuffle_test.py
@@ -40,7 +40,7 @@ from tensorflow.python.platform import test
 class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
 
   @combinations.generate(test_base.default_test_combinations())
-  def testShuffleDataset(self):
+  def testBasic(self):
     components = (
         np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
         np.array([9.0, 10.0, 11.0, 12.0])
@@ -160,7 +160,7 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
 
   @combinations.generate(
       combinations.times(
-          combinations.combine(tf_api_version=[1, 2], mode="graph"),
+          test_base.graph_only_combinations(),
           combinations.combine(reshuffle=[True, False]),
           combinations.combine(graph_seed=38, op_seed=None) +
           combinations.combine(graph_seed=None, op_seed=42) +
@@ -188,7 +188,7 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
   # TODO(b/117581999): enable this test for eager-mode.
   @combinations.generate(
       combinations.times(
-          combinations.combine(tf_api_version=[1, 2], mode="graph"),
+          test_base.graph_only_combinations(),
           combinations.combine(
               reshuffle=[True, False], initializable=[True, False])))
   def testMultipleIterators(self, reshuffle, initializable):
@@ -278,7 +278,7 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
 
   @combinations.generate(
       combinations.times(
-          combinations.combine(tf_api_version=[1, 2], mode="eager"),
+          test_base.eager_only_combinations(),
           combinations.combine(reshuffle=[True, False], seed=[None, 42])))
   def testReshuffleSeparateTransformations(self, reshuffle, seed):
     dataset = dataset_ops.Dataset.range(10)
diff --git a/tensorflow/python/data/kernel_tests/skip_test.py b/tensorflow/python/data/kernel_tests/skip_test.py
index 74dc8b7f55c..176893d90d2 100644
--- a/tensorflow/python/data/kernel_tests/skip_test.py
+++ b/tensorflow/python/data/kernel_tests/skip_test.py
@@ -17,46 +17,30 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import test_util
+from tensorflow.python.framework import combinations
 from tensorflow.python.platform import test
 
 
-@test_util.run_all_in_graph_and_eager_modes
-class SkipTest(test_base.DatasetTestBase):
+class SkipTest(test_base.DatasetTestBase, parameterized.TestCase):
 
-  def testSkipTensorDataset(self):
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         combinations.combine(count=[-1, 0, 4, 10, 25])))
+  def testBasic(self, count):
     components = (np.arange(10),)
-
-    def do_test(count):
-      dataset = dataset_ops.Dataset.from_tensor_slices(components).skip(count)
-      self.assertEqual(
-          [c.shape[1:] for c in components],
-          [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
-      start_range = min(count, 10) if count != -1 else 10
-      self.assertDatasetProduces(
-          dataset,
-          [tuple(components[0][i:i + 1]) for i in range(start_range, 10)])
-
-    # Skip fewer than input size, we should skip
-    # the first 4 elements and then read the rest.
-    do_test(4)
-
-    # Skip more than input size: get nothing.
-    do_test(25)
-
-    # Skip exactly input size.
-    do_test(10)
-
-    # Set -1 for 'count': skip the entire dataset.
-    do_test(-1)
-
-    # Skip nothing
-    do_test(0)
-
+    dataset = dataset_ops.Dataset.from_tensor_slices(components).skip(count)
+    self.assertEqual(
+        [c.shape[1:] for c in components],
+        [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
+    start_range = min(count, 10) if count != -1 else 10
+    self.assertDatasetProduces(
+        dataset,
+        [tuple(components[0][i:i + 1]) for i in range(start_range, 10)])
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/kernel_tests/take_test.py b/tensorflow/python/data/kernel_tests/take_test.py
index 665ed59a7bc..14796551e16 100644
--- a/tensorflow/python/data/kernel_tests/take_test.py
+++ b/tensorflow/python/data/kernel_tests/take_test.py
@@ -17,40 +17,30 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import test_util
+from tensorflow.python.framework import combinations
 from tensorflow.python.platform import test
 
 
-@test_util.run_all_in_graph_and_eager_modes
-class TakeTest(test_base.DatasetTestBase):
+class TakeTest(test_base.DatasetTestBase, parameterized.TestCase):
 
-  def testTakeTensorDataset(self):
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         combinations.combine(count=[-1, 0, 4, 10, 25])))
+  def testBasic(self, count):
     components = (np.arange(10),)
+    dataset = dataset_ops.Dataset.from_tensor_slices(components).take(count)
+    self.assertEqual(
+        [c.shape[1:] for c in components],
+        [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
+    num_output = min(count, 10) if count != -1 else 10
+    self.assertDatasetProduces(
+        dataset, [tuple(components[0][i:i + 1]) for i in range(num_output)])
 
-    def do_test(count):
-      dataset = dataset_ops.Dataset.from_tensor_slices(components).take(count)
-      self.assertEqual(
-          [c.shape[1:] for c in components],
-          [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
-      num_output = min(count, 10) if count != -1 else 10
-      self.assertDatasetProduces(
-          dataset, [tuple(components[0][i:i + 1]) for i in range(num_output)])
-
-    # Take fewer than input size
-    do_test(4)
-
-    # Take more than input size
-    do_test(25)
-
-    # Take all of input
-    do_test(-1)
-
-    # Take nothing
-    do_test(0)
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/data/kernel_tests/test_base.py b/tensorflow/python/data/kernel_tests/test_base.py
index 6dfee4cc0f7..60796b178bf 100644
--- a/tensorflow/python/data/kernel_tests/test_base.py
+++ b/tensorflow/python/data/kernel_tests/test_base.py
@@ -58,7 +58,11 @@ class DatasetTestBase(test.TestCase):
 
   def assertValuesEqual(self, expected, actual):
     """Asserts that two values are equal."""
-    if sparse_tensor.is_sparse(expected):
+    if isinstance(expected, dict):
+      self.assertItemsEqual(list(expected.keys()), list(actual.keys()))
+      for k in expected.keys():
+        self.assertValuesEqual(expected[k], actual[k])
+    elif sparse_tensor.is_sparse(expected):
       self.assertAllEqual(expected.indices, actual.indices)
       self.assertAllEqual(expected.values, actual.values)
       self.assertAllEqual(expected.dense_shape, actual.dense_shape)
diff --git a/tensorflow/python/data/kernel_tests/text_line_dataset_test.py b/tensorflow/python/data/kernel_tests/text_line_dataset_test.py
index c62d4ec8270..35b479faa21 100644
--- a/tensorflow/python/data/kernel_tests/text_line_dataset_test.py
+++ b/tensorflow/python/data/kernel_tests/text_line_dataset_test.py
@@ -21,11 +21,12 @@ import gzip
 import os
 import zlib
 
+from absl.testing import parameterized
+
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import readers
-from tensorflow.python.eager import context
-from tensorflow.python.framework import test_util
+from tensorflow.python.framework import combinations
 from tensorflow.python.platform import test
 from tensorflow.python.util import compat
 
@@ -37,8 +38,7 @@ except ImportError:
   psutil_import_succeeded = False
 
 
-@test_util.run_all_in_graph_and_eager_modes
-class TextLineDatasetTest(test_base.DatasetTestBase):
+class TextLineDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
 
   def _lineText(self, f, l):
     return compat.as_bytes("%d: %d" % (f, l))
@@ -76,7 +76,11 @@ class TextLineDatasetTest(test_base.DatasetTestBase):
 
     return filenames
 
-  def _testTextLineDataset(self, compression_type=None):
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(compression_type=[None, "GZIP", "ZLIB"])))
+  def testTextLineDataset(self, compression_type):
     test_filenames = self._createFiles(
         2, 5, crlf=True, compression_type=compression_type)
 
@@ -115,6 +119,7 @@ class TextLineDatasetTest(test_base.DatasetTestBase):
         expected_output=[[self._lineText(0, i) for i in range(5)],
                          [self._lineText(1, i) for i in range(5)]] * 10)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testTextLineDatasetParallelRead(self):
     test_filenames = self._createFiles(10, 10)
     files = dataset_ops.Dataset.from_tensor_slices(test_filenames).repeat(10)
@@ -125,15 +130,7 @@ class TextLineDatasetTest(test_base.DatasetTestBase):
     self.assertDatasetProduces(
         dataset, expected_output=expected_output * 10, assert_items_equal=True)
 
-  def testTextLineDatasetNoCompression(self):
-    self._testTextLineDataset()
-
-  def testTextLineDatasetGzipCompression(self):
-    self._testTextLineDataset(compression_type="GZIP")
-
-  def testTextLineDatasetZlibCompression(self):
-    self._testTextLineDataset(compression_type="ZLIB")
-
+  @combinations.generate(test_base.default_test_combinations())
   def testTextLineDatasetBuffering(self):
     test_filenames = self._createFiles(2, 5, crlf=True)
 
@@ -143,33 +140,33 @@ class TextLineDatasetTest(test_base.DatasetTestBase):
       expected_output.extend([self._lineText(j, i) for i in range(5)])
     self.assertDatasetProduces(repeat_dataset, expected_output=expected_output)
 
+  @combinations.generate(test_base.eager_only_combinations())
   def testIteratorResourceCleanup(self):
     filename = os.path.join(self.get_temp_dir(), "text.txt")
     with open(filename, "wt") as f:
       for i in range(3):
         f.write("%d\n" % (i,))
-    with context.eager_mode():
-      first_iterator = iter(readers.TextLineDataset(filename))
-      self.assertEqual(b"0", next(first_iterator).numpy())
-      second_iterator = iter(readers.TextLineDataset(filename))
-      self.assertEqual(b"0", next(second_iterator).numpy())
-      # Eager kernel caching is based on op attributes, which includes the
-      # Dataset's output shape. Create a different kernel to test that they
-      # don't create resources with the same names.
-      different_kernel_iterator = iter(
-          readers.TextLineDataset(filename).repeat().batch(16))
-      self.assertEqual([16], next(different_kernel_iterator).shape)
-      # Remove our references to the Python Iterator objects, which (assuming no
-      # reference cycles) is enough to trigger DestroyResourceOp and close the
-      # partially-read files.
-      del first_iterator
-      del second_iterator
-      del different_kernel_iterator
-      if not psutil_import_succeeded:
-        self.skipTest(
-            "psutil is required to check that we've closed our files.")
-      open_files = psutil.Process().open_files()
-      self.assertNotIn(filename, [open_file.path for open_file in open_files])
+    first_iterator = iter(readers.TextLineDataset(filename))
+    self.assertEqual(b"0", next(first_iterator).numpy())
+    second_iterator = iter(readers.TextLineDataset(filename))
+    self.assertEqual(b"0", next(second_iterator).numpy())
+    # Eager kernel caching is based on op attributes, which includes the
+    # Dataset's output shape. Create a different kernel to test that they
+    # don't create resources with the same names.
+    different_kernel_iterator = iter(
+        readers.TextLineDataset(filename).repeat().batch(16))
+    self.assertEqual([16], next(different_kernel_iterator).shape)
+    # Remove our references to the Python Iterator objects, which (assuming no
+    # reference cycles) is enough to trigger DestroyResourceOp and close the
+    # partially-read files.
+    del first_iterator
+    del second_iterator
+    del different_kernel_iterator
+    if not psutil_import_succeeded:
+      self.skipTest(
+          "psutil is required to check that we've closed our files.")
+    open_files = psutil.Process().open_files()
+    self.assertNotIn(filename, [open_file.path for open_file in open_files])
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/kernel_tests/tf_record_dataset_test.py b/tensorflow/python/data/kernel_tests/tf_record_dataset_test.py
index 5cf8308a55f..792c4926640 100644
--- a/tensorflow/python/data/kernel_tests/tf_record_dataset_test.py
+++ b/tensorflow/python/data/kernel_tests/tf_record_dataset_test.py
@@ -21,31 +21,31 @@ import gzip
 import os
 import zlib
 
+from absl.testing import parameterized
+
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import readers
+from tensorflow.python.framework import combinations
 from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import test_util
 from tensorflow.python.lib.io import python_io
 from tensorflow.python.platform import test
 from tensorflow.python.util import compat
 
 
-@test_util.run_all_in_graph_and_eager_modes
-class TFRecordDatasetTest(test_base.DatasetTestBase):
+class TFRecordDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
 
   def setUp(self):
     super(TFRecordDatasetTest, self).setUp()
     self._num_files = 2
     self._num_records = 7
-
     self.test_filenames = self._createFiles()
 
-  def dataset_fn(self,
-                 filenames,
-                 compression_type="",
-                 num_epochs=1,
-                 batch_size=None):
+  def _dataset_factory(self,
+                       filenames,
+                       compression_type="",
+                       num_epochs=1,
+                       batch_size=None):
 
     repeat_dataset = readers.TFRecordDataset(
         filenames, compression_type).repeat(num_epochs)
@@ -67,6 +67,7 @@ class TFRecordDatasetTest(test_base.DatasetTestBase):
       writer.close()
     return filenames
 
+  @combinations.generate(test_base.default_test_combinations())
   def testTFRecordDatasetConstructorErrorsTensorInput(self):
     with self.assertRaisesRegex(TypeError,
                                 "filenames.*must be.*Tensor.*string"):
@@ -78,37 +79,40 @@ class TFRecordDatasetTest(test_base.DatasetTestBase):
     with self.assertRaises(Exception):
       readers.TFRecordDataset(object())
 
+  @combinations.generate(test_base.default_test_combinations())
   def testReadOneEpoch(self):
     # Basic test: read from file 0.
-    dataset = self.dataset_fn(self.test_filenames[0])
+    dataset = self._dataset_factory(self.test_filenames[0])
     self.assertDatasetProduces(
         dataset,
         expected_output=[self._record(0, i) for i in range(self._num_records)])
 
     # Basic test: read from file 1.
-    dataset = self.dataset_fn(self.test_filenames[1])
+    dataset = self._dataset_factory(self.test_filenames[1])
     self.assertDatasetProduces(
         dataset,
         expected_output=[self._record(1, i) for i in range(self._num_records)])
 
     # Basic test: read from both files.
-    dataset = self.dataset_fn(self.test_filenames)
+    dataset = self._dataset_factory(self.test_filenames)
     expected_output = []
     for j in range(self._num_files):
       expected_output.extend(
           [self._record(j, i) for i in range(self._num_records)])
     self.assertDatasetProduces(dataset, expected_output=expected_output)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testReadTenEpochs(self):
-    dataset = self.dataset_fn(self.test_filenames, num_epochs=10)
+    dataset = self._dataset_factory(self.test_filenames, num_epochs=10)
     expected_output = []
     for j in range(self._num_files):
       expected_output.extend(
           [self._record(j, i) for i in range(self._num_records)])
     self.assertDatasetProduces(dataset, expected_output=expected_output * 10)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testReadTenEpochsOfBatches(self):
-    dataset = self.dataset_fn(
+    dataset = self._dataset_factory(
         self.test_filenames, num_epochs=10, batch_size=self._num_records)
     expected_output = []
     for j in range(self._num_files):
@@ -116,6 +120,7 @@ class TFRecordDatasetTest(test_base.DatasetTestBase):
           [self._record(j, i) for i in range(self._num_records)])
     self.assertDatasetProduces(dataset, expected_output=expected_output * 10)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testReadZlibFiles(self):
     zlib_files = []
     for i, fn in enumerate(self.test_filenames):
@@ -130,9 +135,10 @@ class TFRecordDatasetTest(test_base.DatasetTestBase):
     for j in range(self._num_files):
       expected_output.extend(
           [self._record(j, i) for i in range(self._num_records)])
-    dataset = self.dataset_fn(zlib_files, compression_type="ZLIB")
+    dataset = self._dataset_factory(zlib_files, compression_type="ZLIB")
     self.assertDatasetProduces(dataset, expected_output=expected_output)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testReadGzipFiles(self):
     gzip_files = []
     for i, fn in enumerate(self.test_filenames):
@@ -145,9 +151,10 @@ class TFRecordDatasetTest(test_base.DatasetTestBase):
     for j in range(self._num_files):
       expected_output.extend(
           [self._record(j, i) for i in range(self._num_records)])
-    dataset = self.dataset_fn(gzip_files, compression_type="GZIP")
+    dataset = self._dataset_factory(gzip_files, compression_type="GZIP")
     self.assertDatasetProduces(dataset, expected_output=expected_output)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testReadWithBuffer(self):
     one_mebibyte = 2**20
     dataset = readers.TFRecordDataset(
@@ -158,6 +165,7 @@ class TFRecordDatasetTest(test_base.DatasetTestBase):
           [self._record(j, i) for i in range(self._num_records)])
     self.assertDatasetProduces(dataset, expected_output=expected_output)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testReadFromDatasetOfFiles(self):
     files = dataset_ops.Dataset.from_tensor_slices(self.test_filenames)
     expected_output = []
@@ -167,6 +175,7 @@ class TFRecordDatasetTest(test_base.DatasetTestBase):
     dataset = readers.TFRecordDataset(files)
     self.assertDatasetProduces(dataset, expected_output=expected_output)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testReadTenEpochsFromDatasetOfFilesInParallel(self):
     files = dataset_ops.Dataset.from_tensor_slices(
         self.test_filenames).repeat(10)
diff --git a/tensorflow/python/data/kernel_tests/unbatch_test.py b/tensorflow/python/data/kernel_tests/unbatch_test.py
index 5bb4852d534..44d949385b0 100644
--- a/tensorflow/python/data/kernel_tests/unbatch_test.py
+++ b/tensorflow/python/data/kernel_tests/unbatch_test.py
@@ -23,11 +23,11 @@ import numpy as np
 
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import combinations
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import string_ops
@@ -36,13 +36,14 @@ from tensorflow.python.platform import test
 from tensorflow.python.util import compat
 
 
-@test_util.run_all_in_graph_and_eager_modes
 class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
 
+  @combinations.generate(test_base.default_test_combinations())
   def testUnbatchWithUnknownRankInput(self):
     dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3]).unbatch()
     self.assertDatasetProduces(dataset, range(4))
 
+  @combinations.generate(test_base.default_test_combinations())
   def testUnbatchScalarDataset(self):
     data = tuple([math_ops.range(10) for _ in range(3)])
     data = dataset_ops.Dataset.from_tensor_slices(data)
@@ -54,12 +55,14 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
 
     self.assertDatasetProduces(data, [(i,) * 3 for i in range(10)])
 
+  @combinations.generate(test_base.default_test_combinations())
   def testUnbatchNestedDataset(self):
     data = dataset_ops.Dataset.from_tensors(
         [dataset_ops.Dataset.range(10) for _ in range(10)])
     data = data.unbatch().flat_map(lambda x: x)
     self.assertDatasetProduces(data, list(range(10)) * 10)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testUnbatchDatasetWithStrings(self):
     data = tuple([math_ops.range(10) for _ in range(3)])
     data = dataset_ops.Dataset.from_tensor_slices(data)
@@ -73,6 +76,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertDatasetProduces(
         data, [(i, compat.as_bytes(str(i)), i) for i in range(10)])
 
+  @combinations.generate(test_base.default_test_combinations())
   def testUnbatchDatasetWithSparseTensor(self):
     st = sparse_tensor.SparseTensorValue(
         indices=[[i, i] for i in range(10)],
@@ -87,6 +91,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     ]
     self.assertDatasetProduces(data, expected_output=expected_output)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testUnbatchDatasetWithDenseSparseAndRaggedTensor(self):
     st = sparse_tensor.SparseTensorValue(
         indices=[[i, i] for i in range(10)],
@@ -104,6 +109,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertDatasetProduces(
         data, expected_output=expected_output)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testUnbatchDatasetWithRaggedTensor(self):
     rt = ragged_factory_ops.constant_value([[[0]], [[1]], [[2]], [[3]], [[4]],
                                             [[5]], [[6]], [[7]], [[8]], [[9]]])
@@ -119,6 +125,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertDatasetProduces(
         data, expected_output=expected_output)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testUnbatchSingleElementTupleDataset(self):
     data = tuple([(math_ops.range(10),) for _ in range(3)])
     data = dataset_ops.Dataset.from_tensor_slices(data)
@@ -130,6 +137,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
 
     self.assertDatasetProduces(data, [((i,),) * 3 for i in range(10)])
 
+  @combinations.generate(test_base.default_test_combinations())
   def testUnbatchMultiElementTupleDataset(self):
     data = tuple([(math_ops.range(10 * i, 10 * i + 10),
                    array_ops.fill([10], "hi")) for i in range(3)])
@@ -146,6 +154,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
         data,
         [((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")) for i in range(10)])
 
+  @combinations.generate(test_base.default_test_combinations())
   def testUnbatchEmpty(self):
     data = dataset_ops.Dataset.from_tensors(
         (constant_op.constant([]), constant_op.constant([], shape=[0, 4]),
@@ -153,15 +162,15 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     data = data.unbatch()
     self.assertDatasetProduces(data, [])
 
+  @combinations.generate(test_base.default_test_combinations())
   def testUnbatchStaticShapeMismatch(self):
     data = dataset_ops.Dataset.from_tensors((np.arange(7), np.arange(8),
                                              np.arange(9)))
     with self.assertRaises(ValueError):
       data.unbatch()
 
-  # Note: dynamic shape mismatch is graph specific test.
-  @test_util.run_deprecated_v1
-  def testSkipEagerUnbatchDynamicShapeMismatch(self):
+  @combinations.generate(test_base.graph_only_combinations())
+  def testUnbatchDynamicShapeMismatch(self):
     ph1 = array_ops.placeholder(dtypes.int32, shape=[None])
     ph2 = array_ops.placeholder(dtypes.int32, shape=None)
     data = dataset_ops.Dataset.from_tensors((ph1, ph2))
@@ -190,6 +199,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
       with self.assertRaises(errors.InvalidArgumentError):
         self.evaluate(next_element)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testUnbatchDatasetWithUintDtypes(self):
     components = (
         np.tile(np.array([[0], [1], [2], [3]], dtype=np.uint8), 2),
diff --git a/tensorflow/python/data/kernel_tests/window_test.py b/tensorflow/python/data/kernel_tests/window_test.py
index 122e874f0a0..98b453a5900 100644
--- a/tensorflow/python/data/kernel_tests/window_test.py
+++ b/tensorflow/python/data/kernel_tests/window_test.py
@@ -24,43 +24,32 @@ from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
 from tensorflow.python.eager import context
+from tensorflow.python.framework import combinations
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 
 
-@test_util.run_all_in_graph_and_eager_modes
 class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
 
-  @parameterized.named_parameters(
-      ("1", 20, 14, 7, 1),
-      ("2", 20, 17, 9, 1),
-      ("3", 20, 14, 14, 1),
-      ("4", 20, 10, 14, 1),
-      ("5", 20, 14, 19, 1),
-      ("6", 20, 4, 1, 2),
-      ("7", 20, 2, 1, 6),
-      ("8", 20, 4, 7, 2),
-      ("9", 20, 2, 7, 6),
-      ("10", 1, 10, 4, 1),
-      ("11", 0, 10, 4, 1),
-      ("12", 20, 14, 7, 1, False),
-      ("13", 20, 17, 9, 1, False),
-      ("14", 20, 14, 14, 1, False),
-      ("15", 20, 10, 14, 1, False),
-      ("16", 20, 14, 19, 1, False),
-      ("17", 20, 4, 1, 2, False),
-      ("18", 20, 2, 1, 6, False),
-      ("19", 20, 4, 7, 2, False),
-      ("20", 20, 2, 7, 6, False),
-      ("21", 1, 10, 4, 1, False),
-      ("22", 0, 10, 4, 1, False),
-  )
-  def testWindowDataset(self, count, size, shift, stride, drop_remainder=True):
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(
+              count=20,
+              size=[10, 14, 17],
+              shift=[7, 14],
+              stride=[1, 2, 6],
+              drop_remainder=[True, False]) + combinations.combine(
+                  count=[0, 1],
+                  size=10,
+                  shift=4,
+                  stride=1,
+                  drop_remainder=[True, False])))
+  def testWindowDataset(self, count, size, shift, stride, drop_remainder):
     """Tests a dataset that slides a window its input elements."""
     components = (np.arange(7),
                   np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
@@ -111,11 +100,12 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(get_next())
 
-  @parameterized.named_parameters(
-      ("1", 14, 0, 3, 1),
-      ("2", 14, 3, 0, 1),
-      ("3", 14, 3, 3, 0),
-  )
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(count=20, size=0, shift=3, stride=1) +
+          combinations.combine(count=20, size=3, shift=0, stride=1) +
+          combinations.combine(count=20, size=3, shift=3, stride=0)))
   def testWindowDatasetInvalid(self, count, size, shift, stride):
     with self.assertRaises(errors.InvalidArgumentError):
       ds = dataset_ops.Dataset.range(10).map(lambda x: x).repeat(count).window(
@@ -123,12 +113,14 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
           stride=stride).flat_map(lambda x: x.batch(batch_size=size))
       self.evaluate(ds._variant_tensor)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testWindowDifferentNestedStructures(self):
     ds = dataset_ops.Dataset.from_tensor_slices(([1, 2], [3, 4])).window(2)
     self.getNext(ds)
     ds = dataset_ops.Dataset.from_tensor_slices({"a": [1, 2]}).window(2)
     self.getNext(ds)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testWindowSparse(self):
 
     def _sparse(i):
@@ -148,6 +140,7 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
     ]
     self.assertDatasetProduces(dataset, expected_output=expected_output)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testWindowSparseWithDifferentDenseShapes(self):
 
     def _sparse(i):
@@ -177,6 +170,7 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
               dense_shape=[5, i * 3 + 5 - 1]))
     self.assertDatasetProduces(dataset, expected_output=expected_output)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testNestedWindowSparse(self):
 
     def _sparse(i):
@@ -205,6 +199,7 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
     ]
     self.assertDatasetProduces(dataset, expected_output=expected_output)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testWindowShapeError(self):
 
     def generator():
@@ -222,6 +217,7 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
             r"Cannot batch tensors with different shapes in component 0. "
             r"First element had shape \[3\] and element 2 had shape \[4\]."))
 
+  @combinations.generate(test_base.default_test_combinations())
   def testWindowIgnoreErrors(self):
     input_values = np.float32([1., np.nan, 2., np.nan, 3.])
     dataset = dataset_ops.Dataset.from_tensor_slices(input_values).map(
@@ -232,6 +228,7 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
         dataset, expected_output=[np.float32([1., 2.]),
                                   np.float32([2., 3.])])
 
+  @combinations.generate(test_base.default_test_combinations())
   def testNestedOutput(self):
     if not context.executing_eagerly():
       self.skipTest("self.evaluate() does not work with a dataset")
diff --git a/tensorflow/python/data/kernel_tests/zip_test.py b/tensorflow/python/data/kernel_tests/zip_test.py
index 72f739e4e4e..c63091754c3 100644
--- a/tensorflow/python/data/kernel_tests/zip_test.py
+++ b/tensorflow/python/data/kernel_tests/zip_test.py
@@ -17,66 +17,68 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import combinations
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import test_util
 from tensorflow.python.platform import test
 
 
-@test_util.run_all_in_graph_and_eager_modes
-class ZipTest(test_base.DatasetTestBase):
+def _dataset_factory(components):
+  datasets = tuple([
+      dataset_ops.Dataset.from_tensor_slices(component)
+      for component in components
+  ])
+  return dataset_ops.Dataset.zip(datasets)
 
-  def testZipDataset(self):
 
-    def dataset_fn(components):
-      datasets = tuple([
-          dataset_ops.Dataset.from_tensor_slices(component)
-          for component in components
-      ])
-      return dataset_ops.Dataset.zip(datasets)
+class ZipTest(test_base.DatasetTestBase, parameterized.TestCase):
 
-    equal_length_components = [
+  @combinations.generate(test_base.default_test_combinations())
+  def testZipEqual(self):
+    components = [
         np.tile(np.array([[1], [2], [3], [4]]), 20),
         np.tile(np.array([[12], [13], [14], [15]]), 22),
         np.array([37.0, 38.0, 39.0, 40.0])
     ]
-
-    get_next = self.getNext(dataset_fn(equal_length_components))
+    get_next = self.getNext(_dataset_factory(components))
     for i in range(4):
       results = self.evaluate(get_next())
-      for component, result_component in zip(equal_length_components, results):
+      for component, result_component in zip(components, results):
         self.assertAllEqual(component[i], result_component)
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(get_next())
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(get_next())
 
-    variable_length_components = [[1, 2, 3, 4], [1, 2, 3, 4, 5], [1.0, 2.0]]
-    get_next = self.getNext(dataset_fn(variable_length_components))
+  @combinations.generate(test_base.default_test_combinations())
+  def testZipUnequal(self):
+    components = [[1, 2, 3, 4], [1, 2, 3, 4, 5], [1.0, 2.0]]
+    get_next = self.getNext(_dataset_factory(components))
     for i in range(2):
       results = self.evaluate(get_next())
-      for component, result_component in zip(variable_length_components,
-                                             results):
+      for component, result_component in zip(components, results):
         self.assertAllEqual(component[i], result_component)
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(get_next())
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(get_next())
 
-  def testNestedZipDataset(self):
+  @combinations.generate(test_base.default_test_combinations())
+  def testNested(self):
 
-    equal_length_components = [
+    components = [
         np.tile(np.array([[1], [2], [3], [4]]), 20),
         np.tile(np.array([[12], [13], [14], [15]]), 22),
         np.array([37.0, 38.0, 39.0, 40.0])
     ]
     datasets = [
         dataset_ops.Dataset.from_tensor_slices(component)
-        for component in equal_length_components
+        for component in components
     ]
     dataset = dataset_ops.Dataset.zip((datasets[0], (datasets[1], datasets[2])))
 
@@ -88,9 +90,9 @@ class ZipTest(test_base.DatasetTestBase):
     get_next = self.getNext(dataset)
     for i in range(4):
       result1, (result2, result3) = self.evaluate(get_next())
-      self.assertAllEqual(equal_length_components[0][i], result1)
-      self.assertAllEqual(equal_length_components[1][i], result2)
-      self.assertAllEqual(equal_length_components[2][i], result3)
+      self.assertAllEqual(components[0][i], result1)
+      self.assertAllEqual(components[1][i], result2)
+      self.assertAllEqual(components[2][i], result3)
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(get_next())
     with self.assertRaises(errors.OutOfRangeError):
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index f67dec9d720..f3367023a7b 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -66,7 +66,6 @@ from tensorflow.python.ops import gen_io_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import script_ops
 from tensorflow.python.ops import string_ops
-from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training.tracking import base as tracking_base
 from tensorflow.python.training.tracking import tracking
 from tensorflow.python.util import deprecation
@@ -2437,26 +2436,25 @@ class DatasetV1Adapter(DatasetV1):
 
 def _ensure_same_dataset_graph(dataset):
   """Walks the dataset graph to ensure all datasets come from the same graph."""
+  # pylint: disable=protected-access
   current_graph = ops.get_default_graph()
   bfs_q = Queue.Queue()
-  bfs_q.put(dataset)  # pylint: disable=protected-access
+  bfs_q.put(dataset)
   visited = []
   while not bfs_q.empty():
     ds = bfs_q.get()
     visited.append(ds)
-    ds_graph = ds._graph  # pylint: disable=protected-access
+    ds_graph = ds._graph
     if current_graph != ds_graph:
-      logging.warning("The graph (" + str(current_graph) + ") of the iterator "
-                      "is different from the graph (" + str(ds_graph) + ") "
-                      "the dataset: " + str(ds._variant_tensor) + " was "  # pylint: disable=protected-access
-                      "created in. If you are using the Estimator API, "
-                      "make sure that no part of the dataset returned by the "
-                      "`input_fn` function is defined outside the `input_fn` "
-                      "function. Please ensure that all datasets in the "
-                      "pipeline are created in the same graph as the iterator. "
-                      "NOTE: This warning will become an error in future "
-                      "versions of TensorFlow.")
-    for input_ds in ds._inputs():  # pylint: disable=protected-access
+      raise ValueError(
+          "The graph (" + str(current_graph) + ") of the iterator is different "
+          "from the graph (" + str(ds_graph) + ") the dataset: " +
+          str(ds._variant_tensor) + " was  created in. If you are using the "
+          "Estimator API, make sure that no part of the dataset returned by "
+          "the `input_fn` function is defined outside the `input_fn` function. "
+          "Please ensure that all datasets in the pipeline are created in the "
+          "same graph as the iterator.")
+    for input_ds in ds._inputs():
       if input_ds not in visited:
         bfs_q.put(input_ds)
 

From b40f26fdf0490949476c56e654fe11ed3b9cdf46 Mon Sep 17 00:00:00 2001
From: Gunhan Gulsoy 
Date: Mon, 2 Dec 2019 11:29:03 -0800
Subject: [PATCH 156/279] Delete the string that looks like a privatee key from
 oauth_client_test. the test is just as happy with a random string.

PiperOrigin-RevId: 283380661
Change-Id: I89b5ca1eca8e850ad1abddfde8be9a851643ef9e
---
 tensorflow/core/platform/cloud/oauth_client_test.cc | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tensorflow/core/platform/cloud/oauth_client_test.cc b/tensorflow/core/platform/cloud/oauth_client_test.cc
index 1b04b1cf827..88e11a58d36 100644
--- a/tensorflow/core/platform/cloud/oauth_client_test.cc
+++ b/tensorflow/core/platform/cloud/oauth_client_test.cc
@@ -35,7 +35,7 @@ constexpr char kTestData[] = "core/platform/cloud/testdata/";
 
 constexpr char kTokenJson[] = R"(
     {
-      "access_token":"1/fFAGRNJru1FTz70BzhT3Zg",
+      "access_token":"WITH_FAKE_ACCESS_TOKEN_TEST_SHOULD_BE_HAPPY",
       "expires_in":3920,
       "token_type":"Bearer"
     })";
@@ -56,7 +56,7 @@ TEST(OAuthClientTest, ParseOAuthResponse) {
   uint64 expiration_timestamp;
   TF_EXPECT_OK(OAuthClient().ParseOAuthResponse(kTokenJson, request_timestamp,
                                                 &token, &expiration_timestamp));
-  EXPECT_EQ("1/fFAGRNJru1FTz70BzhT3Zg", token);
+  EXPECT_EQ("WITH_FAKE_ACCESS_TOKEN_TEST_SHOULD_BE_HAPPY", token);
   EXPECT_EQ(4020, expiration_timestamp);
 }
 
@@ -87,7 +87,7 @@ TEST(OAuthClientTest, GetTokenFromRefreshTokenJson) {
   TF_EXPECT_OK(client.GetTokenFromRefreshTokenJson(
       json, "https://www.googleapis.com/oauth2/v3/token", &token,
       &expiration_timestamp));
-  EXPECT_EQ("1/fFAGRNJru1FTz70BzhT3Zg", token);
+  EXPECT_EQ("WITH_FAKE_ACCESS_TOKEN_TEST_SHOULD_BE_HAPPY", token);
   EXPECT_EQ(13920, expiration_timestamp);
 }
 
@@ -113,7 +113,7 @@ TEST(OAuthClientTest, GetTokenFromServiceAccountJson) {
   TF_EXPECT_OK(client.GetTokenFromServiceAccountJson(
       json, "https://www.googleapis.com/oauth2/v3/token",
       "https://test-token-scope.com", &token, &expiration_timestamp));
-  EXPECT_EQ("1/fFAGRNJru1FTz70BzhT3Zg", token);
+  EXPECT_EQ("WITH_FAKE_ACCESS_TOKEN_TEST_SHOULD_BE_HAPPY", token);
   EXPECT_EQ(13920, expiration_timestamp);
 
   // Now look at the JWT claim that was sent to the OAuth server.

From c6f0c438213914a1ffef5f198f270a83808a41a3 Mon Sep 17 00:00:00 2001
From: Yunxing Dai 
Date: Mon, 2 Dec 2019 11:29:57 -0800
Subject: [PATCH 157/279] [XLA] Dynamic padder: Support dynamic concat.

PiperOrigin-RevId: 283380862
Change-Id: Ic3c6de5054003b276de400ca8a299ba9c2fec000
---
 .../service/dynamic_dimension_inference.cc    | 43 +++++++++++++--
 .../compiler/xla/service/dynamic_padder.cc    | 55 ++++++++++++++++++-
 .../xla/service/dynamic_padder_test.cc        | 45 +++++++++++++++
 3 files changed, 137 insertions(+), 6 deletions(-)

diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
index 1274c095b95..14ea6f988cb 100644
--- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
+++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
@@ -419,14 +419,47 @@ Status DynamicDimensionInferenceVisitor::HandleConvolution(
 
 Status DynamicDimensionInferenceVisitor::HandleConcatenate(
     HloInstruction* hlo) {
+  // First handle concatenate dimensions. We do this by iterating through all
+  // operands while tracking both dynamic and static dimensions.
+
+  // static_size is used to keep track of the concated size of static
+  // dimensions.
+  int64 static_size = 0;
+  std::vector dynamic_concat_dims;
+  for (int64 i = 0; i < hlo->operand_count(); ++i) {
+    HloInstruction* dynamic_size = parent_->GetDynamicSize(
+        hlo->mutable_operand(i), {}, hlo->concatenate_dimension());
+    if (dynamic_size == nullptr) {
+      // This is a static dimension.
+      static_size +=
+          hlo->operand(i)->shape().dimensions(hlo->concatenate_dimension());
+    } else {
+      dynamic_concat_dims.push_back(dynamic_size);
+    }
+  }
+  // If concat dimension is dynamic, calculate its size by summing up static
+  // dims and dynamic dims together.
+  if (!dynamic_concat_dims.empty()) {
+    HloInstruction* dim_size_total =
+        hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
+            LiteralUtil::CreateR0(static_size)));
+    for (HloInstruction* dynamic_dim : dynamic_concat_dims) {
+      dim_size_total = hlo->parent()->AddInstruction(
+          HloInstruction::CreateBinary(dim_size_total->shape(), HloOpcode::kAdd,
+                                       dim_size_total, dynamic_dim));
+    }
+    parent_->SetDynamicSize(hlo, {}, hlo->concatenate_dimension(),
+                            dim_size_total, {.stride = 1, .multiple_of = 1});
+  }
+
+  // Simply pass through non-concat dynamic dimensions.
   return ForEachOperandDynamicDimension(
       hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
                int64 operand_index, HloInstruction* dynamic_size,
                DimensionConstraint constraint) {
         int64 concatenate_dimension = hlo->concatenate_dimension();
         if (concatenate_dimension == dimension) {
-          return Unimplemented("Dynamic concatenation is not supported yet: %s",
-                               operand->ToString());
+          return Status::OK();
         }
         parent_->SetDynamicSize(hlo, index, dimension, dynamic_size,
                                 constraint);
@@ -1318,9 +1351,9 @@ Status DynamicDimensionInference::ForwardDynamicSize(HloInstruction* inst,
     DynamicDimension dynamic_dimension{inst, index, dim};
     auto iter = dynamic_mapping_.find(dynamic_dimension);
     if (iter != dynamic_mapping_.end()) {
-      dynamic_mapping_.try_emplace(dynamic_dimension_new, iter->second);
-      constraint_mapping_.try_emplace(dynamic_dimension_new,
-                                      constraint_mapping_[dynamic_dimension]);
+      dynamic_mapping_.insert({dynamic_dimension_new, iter->second});
+      constraint_mapping_.insert(
+          {dynamic_dimension_new, constraint_mapping_[dynamic_dimension]});
       auto iter = per_hlo_dynamic_dimensions_.try_emplace(new_inst);
       iter.first->second.emplace(dynamic_dimension_new);
     }
diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc
index 8d92b7e985a..c94a2594f3b 100644
--- a/tensorflow/compiler/xla/service/dynamic_padder.cc
+++ b/tensorflow/compiler/xla/service/dynamic_padder.cc
@@ -567,7 +567,55 @@ Status RewriteDynamicReshapeSingleDim(
   }
   return Status::OK();
 }
-
+StatusOr RewriteDynamicConcat(
+    HloInstruction* concat,
+    DynamicDimensionInference* dynamic_dimension_inference) {
+  const int64 concat_dim = concat->concatenate_dimension();
+  HloComputation* comp = concat->parent();
+  if (dynamic_dimension_inference->GetDynamicSize(concat, {}, concat_dim) ==
+      nullptr) {
+    // Concat dimension is not dynamic -- no rewrite needed.
+    return false;
+  }
+  std::vector offsets;
+  for (int64 i = 0; i < concat->shape().dimensions_size(); ++i) {
+    offsets.push_back(comp->AddInstruction(
+        HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))));
+  }
+  HloInstruction* rewritten_concat = concat;
+  // Keep track of previous users before rewrite so that we can update their
+  // operands later.
+  auto prev_users = concat->users();
+  for (int64 i = 0; i < concat->operand_count(); ++i) {
+    // Rewrite the concat by dynamic update slicing operand into the concat dim.
+    HloInstruction* operand = concat->mutable_operand(i);
+    rewritten_concat =
+        comp->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+            rewritten_concat->shape(), rewritten_concat, operand, offsets));
+    // Update the offset of concat dimension by adding the size of the concat
+    // dimension of the operand to it.
+    HloInstruction* dynamic_size =
+        dynamic_dimension_inference->GetDynamicSize(operand, {}, concat_dim);
+    if (dynamic_size == nullptr) {
+      HloInstruction* static_size = comp->AddInstruction(
+          HloInstruction::CreateConstant(LiteralUtil::CreateR0(
+              operand->shape().dimensions(concat_dim))));
+      offsets[concat_dim] = comp->AddInstruction(HloInstruction::CreateBinary(
+          ShapeUtil::MakeScalarShape(S32), HloOpcode::kAdd, offsets[concat_dim],
+          static_size));
+    } else {
+      offsets[concat_dim] = comp->AddInstruction(HloInstruction::CreateBinary(
+          ShapeUtil::MakeScalarShape(S32), HloOpcode::kAdd, offsets[concat_dim],
+          dynamic_size));
+    }
+  }
+  for (HloInstruction* user : prev_users) {
+    TF_RETURN_IF_ERROR(concat->ReplaceUseWith(user, rewritten_concat));
+  }
+  TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
+      concat, rewritten_concat, {}));
+  return true;
+}
 StatusOr RewriteDynamicReshape(
     HloInstruction* reshape,
     DynamicDimensionInference* dynamic_dimension_inference) {
@@ -709,6 +757,11 @@ StatusOr DynamicPadder::Run(HloModule* module) {
 
   for (HloComputation* computation : module->computations()) {
     for (HloInstruction* inst : computation->instructions()) {
+      if (inst->opcode() == HloOpcode::kConcatenate) {
+        TF_ASSIGN_OR_RETURN(
+            changed, RewriteDynamicConcat(inst, &dynamic_dimension_inference));
+        continue;
+      }
       for (int64 operand_num = 0; operand_num < inst->operand_count();
            ++operand_num) {
         HloInstruction* original_operand = inst->mutable_operand(operand_num);
diff --git a/tensorflow/compiler/xla/service/dynamic_padder_test.cc b/tensorflow/compiler/xla/service/dynamic_padder_test.cc
index 6c3f0bec493..0e60e420d47 100644
--- a/tensorflow/compiler/xla/service/dynamic_padder_test.cc
+++ b/tensorflow/compiler/xla/service/dynamic_padder_test.cc
@@ -496,6 +496,51 @@ ENTRY main {
   EXPECT_EQ(result, expected);
 }
 
+XLA_TEST_F(ExecutionTest, DynamicConcat) {
+  // Concatting a list of {dynamic_operand, static_operand, dynamic_operand}.
+  const string hlo_text = R"(
+HloModule DynamicConcat
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+  lhs = s32[] parameter(0)
+  rhs = s32[] parameter(1)
+  ROOT add = s32[] add(lhs, rhs)
+}
+
+ENTRY main {
+  param_0 = s32[3] parameter(0)
+  param_1 = s32[3] parameter(1)
+  param_2 = s32[3] parameter(2)
+  size = s32[] constant(2)
+  param_padded_0 = s32[3] set-dimension-size(param_0, size), dimensions={0}
+  param_padded_2 = s32[3] set-dimension-size(param_2, size), dimensions={0}
+  %concatenate = s32[9]
+    concatenate(s32[3] param_padded_0, s32[3] param_1, s32[3] param_padded_2),
+    dimensions={0}
+  init = s32[] constant(0)
+  ROOT reduce = s32[] reduce(concatenate, init),
+      dimensions={0},
+      to_apply=update_s32
+}
+)";
+
+  // Input has upper bound of 3, dynamic dimension is 2. Using -1 as padding.
+  Literal operand_0 =
+      LiteralUtil::CreateR1({1, 2, -1});  // Dynamic operand.
+  Literal operand_1 =
+      LiteralUtil::CreateR1({3, 4, 5});  // Static operand.
+  Literal operand_2 =
+      LiteralUtil::CreateR1({6, 7, -1});  // Dynamic operand.
+  auto module = GetHloModule(hlo_text);
+
+  Literal result =
+      PadAndExecute(std::move(module), {&operand_0, &operand_1, &operand_2});
+
+  Literal expected = LiteralUtil::CreateR0(28);
+
+  EXPECT_EQ(result, expected);
+}
+
 XLA_TEST_F(ExecutionTest, DynamicDimensionReduce) {
   const string hlo_text = R"(
 HloModule TensorFlowScatterV1

From 3bc9dd3b776c26fbd8d4eba7917fe873a5596e40 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Mon, 2 Dec 2019 11:37:15 -0800
Subject: [PATCH 158/279] Update Eigen to
 https://bitbucket.org/eigen/eigen/commits/111653465dd592db57128241ed9f3a2c59144028

PiperOrigin-RevId: 283382426
Change-Id: Ia81cb980e8ac700f9c2cbfe8ae4f7b221ebd70e9
---
 tensorflow/workspace.bzl | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 9f1d4cce63a..a79444c221c 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -172,11 +172,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
         name = "eigen_archive",
         build_file = clean_dep("//third_party:eigen.BUILD"),
         patch_file = clean_dep("//third_party/eigen3:gpu_packet_math.patch"),
-        sha256 = "6d8ed482addd14892d7b0bd98fec2c02f18fdab97775bda68c3f2a99ffb190fb",
-        strip_prefix = "eigen-eigen-66be6c76fc01",
+        sha256 = "9edd4860b52813eaf8c023f0de1767ec58e2d67a290b718e6702469208ac5be1",
+        strip_prefix = "eigen-eigen-54bca9936424",
         urls = [
-            "https://storage.googleapis.com/mirror.tensorflow.org/bitbucket.org/eigen/eigen/get/66be6c76fc01.tar.gz",
-            "https://bitbucket.org/eigen/eigen/get/66be6c76fc01.tar.gz",
+            "https://storage.googleapis.com/mirror.tensorflow.org/bitbucket.org/eigen/eigen/get/54bca9936424.tar.gz",
+            "https://bitbucket.org/eigen/eigen/get/54bca9936424.tar.gz",
         ],
     )
 

From 98203f896c2b0a25b4e5d3d62475465df7138f79 Mon Sep 17 00:00:00 2001
From: Juhyun Lee 
Date: Mon, 2 Dec 2019 11:56:57 -0800
Subject: [PATCH 159/279] Update supported CONV_2D version from 1 to 2.

The GPU delegate already supports dilated conv.

https://github.com/tensorflow/tensorflow/issues/34679

PiperOrigin-RevId: 283386353
Change-Id: I6eb9edb4083933455bce49a86983aa5a4f9c5483
---
 tensorflow/lite/delegates/gpu/common/model_builder.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc
index d7fe8938699..4aec15f0b67 100644
--- a/tensorflow/lite/delegates/gpu/common/model_builder.cc
+++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc
@@ -843,7 +843,7 @@ class Conv2DOperationParser : public TFLiteOperationParser {
   Status IsSupported(const TfLiteContext* context,
                      const TfLiteNode* tflite_node,
                      const TfLiteRegistration* registration) final {
-    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
+    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
     RETURN_IF_ERROR(
         CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
     RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));

From 6b71000526adcd74ee30972218ce567a2277b85c Mon Sep 17 00:00:00 2001
From: Bruce Fontaine 
Date: Mon, 2 Dec 2019 12:09:31 -0800
Subject: [PATCH 160/279] Fix issue that prevents using both dense and sparse
 features in a single model with TPUEstimator.

PiperOrigin-RevId: 283389133
Change-Id: I2ba1756cb06d86525b3f7463d063cd948f0cee11
---
 tensorflow/python/tpu/tpu_embedding.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/python/tpu/tpu_embedding.py b/tensorflow/python/tpu/tpu_embedding.py
index 76483693dfe..d6e77815041 100644
--- a/tensorflow/python/tpu/tpu_embedding.py
+++ b/tensorflow/python/tpu/tpu_embedding.py
@@ -1044,7 +1044,7 @@ class TPUEmbedding(object):
         sample_indices = (
             enqueue_data.sample_indices
             if enqueue_data.sample_indices is not None else array_ops.zeros(
-                (0,), dtype=dtypes.int32))
+                (0,), dtype=dtypes.int64))
         sample_indices_list.append(sample_indices)
 
         aggregation_weights = (

From 58f9c9b0d9aed8b4036a31fb1f30dfcd26e8cab6 Mon Sep 17 00:00:00 2001
From: Dan Moldovan 
Date: Mon, 2 Dec 2019 12:23:47 -0800
Subject: [PATCH 161/279] Consistently process functions wrapped by
 functools.partial so that both the partial and its target function are
 handled in identical fashion. This fixes bugs where an application of
 functools.partial bypassed some checks, and at the same time requires
 explicit verifications in the test for callable metaclasses, which relied on
 this bug.

PiperOrigin-RevId: 283392058
Change-Id: Id07e7821fa455cf9b6cca469de259f06594730c2
---
 tensorflow/python/autograph/impl/api.py       | 40 ++++++------
 tensorflow/python/autograph/impl/api_test.py  | 14 ++--
 .../python/autograph/pyct/inspect_utils.py    | 21 +++++-
 .../autograph/pyct/inspect_utils_test.py      | 64 ++++++++++++++++---
 4 files changed, 108 insertions(+), 31 deletions(-)

diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py
index 26e766598d7..17104c10c1b 100644
--- a/tensorflow/python/autograph/impl/api.py
+++ b/tensorflow/python/autograph/impl/api.py
@@ -422,6 +422,27 @@ def converted_call(f,
     logging.log(2, 'Whitelisted: %s: AutoGraph is disabled in context', f)
     return _call_unconverted(f, args, kwargs, options, False)
 
+  if is_autograph_artifact(f):
+    logging.log(2, 'Permanently whitelisted: %s: AutoGraph artifact', f)
+    return _call_unconverted(f, args, kwargs, options)
+
+  # If this is a partial, unwrap it and redo all the checks.
+  if isinstance(f, functools.partial):
+    new_kwargs = {}
+    if f.keywords is not None:
+      new_kwargs = f.keywords
+    if kwargs is not None:
+      new_kwargs.update(kwargs)
+    new_args = f.args + args
+    logging.log(3, 'Forwarding call of partial %s with\n%s\n%s\n', f, new_args,
+                new_kwargs)
+    return converted_call(
+        f.func,
+        new_args,
+        new_kwargs,
+        caller_fn_scope=caller_fn_scope,
+        options=options)
+
   if inspect_utils.isbuiltin(f):
     if f is eval:
       return py_builtins.eval_in_original_context(f, args, caller_fn_scope)
@@ -432,10 +453,6 @@ def converted_call(f,
     else:
       return py_builtins.overload_of(f)(*args)
 
-  if is_autograph_artifact(f):
-    logging.log(2, 'Permanently whitelisted: %s: AutoGraph artifact', f)
-    return _call_unconverted(f, args, kwargs, options)
-
   # TODO(b/122265385): Remove this bypass.
   if (_is_known_loaded_type(f, 'wrapt', 'FunctionWrapper') or
       _is_known_loaded_type(f, 'wrapt', 'BoundFunctionWrapper')):
@@ -453,7 +470,7 @@ def converted_call(f,
   # Constructors are permanently whitelisted.
   # TODO(mdan): Toggle as experimental feature instead.
   # TODO(b/124016764): Remove this limitation.
-  if tf_inspect.isclass(f):
+  if inspect_utils.isconstructor(f):
     logging.log(2, 'Permanently whitelisted: %s: constructor', f)
     return _call_unconverted(f, args, kwargs, options)
 
@@ -484,19 +501,6 @@ def converted_call(f,
   # TODO(mdan): Move this entire block inside to_graph.
   try:  # Begin of transformation error guards
 
-    # Unwrap functools.partial objects
-    # TODO(mdan): Consider sharing unwrapping logic with tf_inspect.
-    # TODO(b/120224672): This unwrapping should be done before the checks above.
-    while isinstance(f, functools.partial):
-      args = f.args + args
-      new_kwargs = {}
-      if f.keywords is not None:
-        new_kwargs.update(f.keywords)
-      if kwargs is not None:
-        new_kwargs.update(kwargs)
-      kwargs = new_kwargs
-      f = f.func
-
     if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
       # Regular functions
       target_entity = f
diff --git a/tensorflow/python/autograph/impl/api_test.py b/tensorflow/python/autograph/impl/api_test.py
index 2eac1fefd54..0188af10e2e 100644
--- a/tensorflow/python/autograph/impl/api_test.py
+++ b/tensorflow/python/autograph/impl/api_test.py
@@ -435,10 +435,7 @@ class ApiTest(test.TestCase):
         return inst
 
     tmc = TestMetaclass('TestClass', (), {})
-    # This functools.partial will hide the class form the constructor
-    # check. Not ideal. See b/120224672.
-    tc = api.converted_call(
-        functools.partial(tmc), (), None, options=DEFAULT_RECURSIVE)
+    tc = api.converted_call(tmc, (), None, options=DEFAULT_RECURSIVE)
     self.assertIsInstance(tc, tmc)
 
   def test_converted_call_callable_abc(self):
@@ -500,6 +497,15 @@ class ApiTest(test.TestCase):
     ag_logging.set_verbosity(0, False)
     os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1'
 
+  def test_converted_call_partial_of_whitelisted_method(self):
+
+    def test_fn(_):
+      self.assertFalse(converter_testing.is_inside_generated_code())
+
+    converter_testing.whitelist(test_fn)
+    api.converted_call(
+        functools.partial(test_fn, None), (), None, options=DEFAULT_RECURSIVE)
+
   def test_converted_call_already_converted(self):
 
     def f(x):
diff --git a/tensorflow/python/autograph/pyct/inspect_utils.py b/tensorflow/python/autograph/pyct/inspect_utils.py
index 47c52d2e8bb..ca9a0c9ea5d 100644
--- a/tensorflow/python/autograph/pyct/inspect_utils.py
+++ b/tensorflow/python/autograph/pyct/inspect_utils.py
@@ -93,6 +93,26 @@ def isbuiltin(f):
     return False
 
 
+def isconstructor(cls):
+  """Returns True if the argument is an object constructor.
+
+  In general, any object of type class is a constructor, with the exception
+  of classes created using a callable metaclass.
+  See below for why a callable metaclass is not a trivial combination:
+  https://docs.python.org/2.7/reference/datamodel.html#customizing-class-creation
+
+  Args:
+    cls: Any
+  Returns:
+    Bool
+  """
+  return (
+      inspect.isclass(cls)
+      and not (issubclass(cls.__class__, type)
+               and hasattr(cls.__class__, '__call__')
+               and cls.__class__.__call__ is not type.__call__))
+
+
 def _fix_linecache_record(obj):
   """Fixes potential corruption of linecache in the presence of functools.wraps.
 
@@ -351,4 +371,3 @@ def getfutureimports(entity):
     return tuple()
   return tuple(sorted(name for name, value in entity.__globals__.items()
                       if getattr(value, '__module__', None) == '__future__'))
-
diff --git a/tensorflow/python/autograph/pyct/inspect_utils_test.py b/tensorflow/python/autograph/pyct/inspect_utils_test.py
index f8bd427becc..93b7d8237c5 100644
--- a/tensorflow/python/autograph/pyct/inspect_utils_test.py
+++ b/tensorflow/python/autograph/pyct/inspect_utils_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import abc
 import collections
 import functools
 import imp
@@ -517,14 +518,14 @@ class InspectUtilsTest(test.TestCase):
       def baz(self):
         pass
 
-    self.assertTrue(
-        inspect_utils.getdefiningclass(Subclass.foo, Subclass) is Subclass)
-    self.assertTrue(
-        inspect_utils.getdefiningclass(Subclass.bar, Subclass) is Superclass)
-    self.assertTrue(
-        inspect_utils.getdefiningclass(Subclass.baz, Subclass) is Subclass)
-    self.assertTrue(
-        inspect_utils.getdefiningclass(Subclass.class_method, Subclass) is
+    self.assertIs(
+        inspect_utils.getdefiningclass(Subclass.foo, Subclass), Subclass)
+    self.assertIs(
+        inspect_utils.getdefiningclass(Subclass.bar, Subclass), Superclass)
+    self.assertIs(
+        inspect_utils.getdefiningclass(Subclass.baz, Subclass), Subclass)
+    self.assertIs(
+        inspect_utils.getdefiningclass(Subclass.class_method, Subclass),
         Superclass)
 
   def test_isbuiltin(self):
@@ -537,6 +538,53 @@ class InspectUtilsTest(test.TestCase):
     self.assertTrue(inspect_utils.isbuiltin(zip))
     self.assertFalse(inspect_utils.isbuiltin(function_decorator))
 
+  def test_isconstructor(self):
+
+    class OrdinaryClass(object):
+      pass
+
+    class OrdinaryCallableClass(object):
+
+      def __call__(self):
+        pass
+
+    class Metaclass(type):
+      pass
+
+    class CallableMetaclass(type):
+
+      def __call__(cls):
+        pass
+
+    self.assertTrue(inspect_utils.isconstructor(OrdinaryClass))
+    self.assertTrue(inspect_utils.isconstructor(OrdinaryCallableClass))
+    self.assertTrue(inspect_utils.isconstructor(Metaclass))
+    self.assertTrue(inspect_utils.isconstructor(Metaclass('TestClass', (), {})))
+    self.assertTrue(inspect_utils.isconstructor(CallableMetaclass))
+
+    self.assertFalse(inspect_utils.isconstructor(
+        CallableMetaclass('TestClass', (), {})))
+
+  def test_isconstructor_abc_callable(self):
+
+    @six.add_metaclass(abc.ABCMeta)
+    class AbcBase(object):
+
+      @abc.abstractmethod
+      def __call__(self):
+        pass
+
+    class AbcSubclass(AbcBase):
+
+      def __init__(self):
+        pass
+
+      def __call__(self):
+        pass
+
+    self.assertTrue(inspect_utils.isconstructor(AbcBase))
+    self.assertTrue(inspect_utils.isconstructor(AbcSubclass))
+
   def test_getfutureimports_functions(self):
     self.assertEqual(
         inspect_utils.getfutureimports(basic_definitions.function_with_print),

From dcef9de84fbdf89d6038de6b050fa25819f28b4b Mon Sep 17 00:00:00 2001
From: Lei Zhang 
Date: Mon, 2 Dec 2019 12:26:34 -0800
Subject: [PATCH 162/279] NFC: use `&&` instead of `and`

PiperOrigin-RevId: 283392575
Change-Id: Ib61f12d7617e929fc86e2c23c3c89eef359a0a6d
---
 third_party/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/third_party/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/third_party/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index dcecd1c65be..16894ad4cb3 100644
--- a/third_party/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/third_party/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1261,7 +1261,7 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
       if (emitNotNullCheck) {
         body << formatv("  if ({0}) ", namedAttr.name) << "{\n";
       }
-      if (isRawValueAttr and canUseUnwrappedRawValue(attr)) {
+      if (isRawValueAttr && canUseUnwrappedRawValue(attr)) {
         // If this is a raw value, then we need to wrap it in an Attribute
         // instance.
         FmtContext fctx;

From 654dc1e68c0ff8d15f7a34b24479a81e8ccf0bac Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Mon, 2 Dec 2019 12:54:13 -0800
Subject: [PATCH 163/279] Internal change

PiperOrigin-RevId: 283397905
Change-Id: Idcd37af269be680c9ba3e56142aae5d58aea21f9
---
 tensorflow/BUILD | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index ec879fe2c45..0f299ec13f8 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -2,10 +2,7 @@
 # TensorFlow is a computational framework, primarily for use in machine
 # learning applications.
 
-load("//tensorflow:tensorflow.bzl", "VERSION")
-load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
-load("//tensorflow:tensorflow.bzl", "tf_custom_op_library_additional_deps_impl")
-load("//tensorflow:tensorflow.bzl", "tf_native_cc_binary")
+load("//tensorflow:tensorflow.bzl", "VERSION", "tf_cc_shared_object", "tf_custom_op_library_additional_deps_impl", "tf_native_cc_binary")
 load(
     "//tensorflow/core/platform:build_config.bzl",
     "tf_additional_binary_deps",
@@ -450,6 +447,7 @@ config_setting(
 package_group(
     name = "internal",
     packages = [
+        "//learning/brain/swift/x10/...",
         "//perftools/accelerators/xprof/api/...",
         "//tensorflow/...",
         "//tensorflow_estimator/python/estimator/...",

From 143e80f663573706c9f8759316becb751ef40e30 Mon Sep 17 00:00:00 2001
From: Jian Li 
Date: Mon, 2 Dec 2019 13:00:12 -0800
Subject: [PATCH 164/279] Add comment to shared range method in quantization.
 That method is needed because of restrictions in TFLite's LSTM kernel. It was
 not very clear in the original comments.

PiperOrigin-RevId: 283399107
Change-Id: I51dbcdb49daeaf950d550405d4c31f918cf52ebe
---
 tensorflow/lite/tools/optimize/quantize_model.cc | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/tensorflow/lite/tools/optimize/quantize_model.cc b/tensorflow/lite/tools/optimize/quantize_model.cc
index 304aad618d8..42db16eb965 100644
--- a/tensorflow/lite/tools/optimize/quantize_model.cc
+++ b/tensorflow/lite/tools/optimize/quantize_model.cc
@@ -711,6 +711,11 @@ TfLiteStatus QuantizeIntemediateTensors(ModelT* model,
 // Quantize tensros that have shared range. For example, in LSTM, the output
 // tensor and input state tensor should share the same range because they are
 // using the same scale and zero point.
+// We have to model this explicitely because the output is modeled as an extra
+// tensor in LSTM. In calibrator, state tensors are logged both before and after
+// the inferece so the range is fully captured. But output, although it is
+// identical to activation, is not a state tensor the input value (range) of the
+// very first inference is not captured.
 TfLiteStatus QuantizeSharedRange(ModelT* model, ErrorReporter* error_reporter) {
   for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
        subgraph_idx++) {

From 60cc1188702598e5284698bf890fc18f15e2c27f Mon Sep 17 00:00:00 2001
From: Yu-Cheng Ling 
Date: Mon, 2 Dec 2019 13:09:33 -0800
Subject: [PATCH 165/279] Explicitly disable new converter for "fully_quantize"
 testing code

PiperOrigin-RevId: 283401149
Change-Id: Ib3f36a4d45abb7660a1beb7f905ed798ae4ccf64
---
 tensorflow/lite/testing/toco_convert.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/tensorflow/lite/testing/toco_convert.py b/tensorflow/lite/testing/toco_convert.py
index f4072b241a0..3e8a489c5f8 100644
--- a/tensorflow/lite/testing/toco_convert.py
+++ b/tensorflow/lite/testing/toco_convert.py
@@ -112,9 +112,16 @@ def toco_convert(options, graph_def, input_tensors, output_tensors, **kwargs):
       graphdef_file.flush()
 
       input_shapes = zip_test_utils.get_input_shapes_map(input_tensors)
-      converter = tf.compat.v1.lite.TocoConverter.from_frozen_graph(
+      converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
           graphdef_file.name, input_arrays, output_tensors, input_shapes)
 
+      # TODO(b/145313371): Evaluate should we make it work with the new
+      # converter.
+      # Note: Currently this line is a non-functional change because the new
+      # converter is disabled by default. Since this code path doesn't work
+      # with new converter yet, it's explicitly disabled for easier testing.
+      converter.experimental_new_converter = False
+
       def representative_dataset(input_tensors):
         calibration_inputs = []
         for _, shape, _ in input_tensors:

From c77c0f8176a066bacf9101a3816bd93d53a29751 Mon Sep 17 00:00:00 2001
From: Renjie Liu 
Date: Mon, 2 Dec 2019 13:18:19 -0800
Subject: [PATCH 166/279] Fix lstm tests & also reduce test time.

PiperOrigin-RevId: 283402882
Change-Id: I0a8a4e352586bedc06be2c79286efb04c8014f18
---
 tensorflow/lite/experimental/examples/lstm/BUILD   |  9 +++++----
 .../lstm/bidirectional_sequence_lstm_test.py       | 14 ++++++++++----
 .../lstm/bidirectional_sequence_rnn_test.py        | 14 ++++++++++----
 .../lstm/unidirectional_sequence_lstm_test.py      | 14 ++++++++++----
 .../lstm/unidirectional_sequence_rnn_test.py       | 14 ++++++++++----
 5 files changed, 45 insertions(+), 20 deletions(-)

diff --git a/tensorflow/lite/experimental/examples/lstm/BUILD b/tensorflow/lite/experimental/examples/lstm/BUILD
index 2531889dafb..719e59c6a8c 100644
--- a/tensorflow/lite/experimental/examples/lstm/BUILD
+++ b/tensorflow/lite/experimental/examples/lstm/BUILD
@@ -35,7 +35,7 @@ py_library(
 
 py_test(
     name = "unidirectional_sequence_lstm_test",
-    size = "large",
+    size = "medium",
     srcs = ["unidirectional_sequence_lstm_test.py"],
     python_version = "PY3",
     srcs_version = "PY2AND3",
@@ -58,7 +58,7 @@ py_test(
 
 py_test(
     name = "unidirectional_sequence_rnn_test",
-    size = "large",
+    size = "medium",
     srcs = ["unidirectional_sequence_rnn_test.py"],
     python_version = "PY3",
     srcs_version = "PY2AND3",
@@ -81,7 +81,7 @@ py_test(
 
 py_test(
     name = "bidirectional_sequence_lstm_test",
-    size = "large",
+    size = "medium",
     srcs = ["bidirectional_sequence_lstm_test.py"],
     python_version = "PY3",
     srcs_version = "PY2AND3",
@@ -104,13 +104,14 @@ py_test(
 
 py_test(
     name = "bidirectional_sequence_rnn_test",
-    size = "large",
+    size = "medium",
     srcs = ["bidirectional_sequence_rnn_test.py"],
     python_version = "PY3",
     srcs_version = "PY2AND3",
     tags = [
         "no_oss",
         "no_pip",
+        "notap",  # b/141373014
     ],
     deps = [
         ":rnn",
diff --git a/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_lstm_test.py b/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_lstm_test.py
index f04a265714f..d4b5e2b663a 100644
--- a/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_lstm_test.py
+++ b/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_lstm_test.py
@@ -27,7 +27,9 @@ from tensorflow.python.framework import test_util
 from tensorflow.python.platform import test
 
 # Number of steps to train model.
-TRAIN_STEPS = 1
+# Dial to 0 means no training at all, all the weights will be just using their
+# initial values. This can help make the test smaller.
+TRAIN_STEPS = 0
 
 CONFIG = tf.ConfigProto(device_count={"GPU": 0})
 
@@ -37,7 +39,8 @@ class BidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
   def setUp(self):
     tf.reset_default_graph()
     # Import MNIST dataset
-    self.mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
+    self.mnist = input_data.read_data_sets(
+        "/tmp/data/", fake_data=True, one_hot=True)
 
     # Define constants
     # Unrolled through 28 time steps
@@ -144,8 +147,10 @@ class BidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
     sess.run(init)
     for _ in range(TRAIN_STEPS):
       batch_x, batch_y = self.mnist.train.next_batch(
-          batch_size=self.batch_size, shuffle=False)
+          batch_size=self.batch_size, fake_data=True)
 
+      batch_x = np.array(batch_x)
+      batch_y = np.array(batch_y)
       batch_x = batch_x.reshape((self.batch_size, self.time_steps,
                                  self.n_input))
       sess.run(opt, feed_dict={x: batch_x, y: batch_y})
@@ -200,7 +205,8 @@ class BidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
       - Expected output.
 
     """
-    b1, _ = self.mnist.train.next_batch(batch_size=1)
+    b1, _ = self.mnist.train.next_batch(batch_size=1, fake_data=True)
+    b1 = np.array(b1, dtype=np.dtype("float32"))
     sample_input = np.reshape(b1, (1, self.time_steps, self.n_input))
 
     expected_output = sess.run(output_class, feed_dict={x: sample_input})
diff --git a/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_rnn_test.py b/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_rnn_test.py
index 606f969b92a..b90d4d52b29 100644
--- a/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_rnn_test.py
+++ b/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_rnn_test.py
@@ -31,7 +31,9 @@ from tensorflow.python.platform import test
 FLAGS = flags.FLAGS
 
 # Number of steps to train model.
-TRAIN_STEPS = 1
+# Dial to 0 means no training at all, all the weights will be just using their
+# initial values. This can help make the test smaller.
+TRAIN_STEPS = 0
 
 CONFIG = tf.ConfigProto(device_count={"GPU": 0})
 
@@ -58,7 +60,8 @@ class BidirectionalSequenceRnnTest(test_util.TensorFlowTestCase):
     super(BidirectionalSequenceRnnTest, self).setUp()
     # Import MNIST dataset
     data_dir = tempfile.mkdtemp(dir=FLAGS.test_tmpdir)
-    self.mnist = input_data.read_data_sets(data_dir, one_hot=True)
+    self.mnist = input_data.read_data_sets(
+        data_dir, fake_data=True, one_hot=True)
 
   def buildRnnLayer(self):
     return tf.keras.layers.StackedRNNCells([
@@ -165,8 +168,10 @@ class BidirectionalSequenceRnnTest(test_util.TensorFlowTestCase):
     sess.run(init)
     for _ in range(TRAIN_STEPS):
       batch_x, batch_y = self.mnist.train.next_batch(
-          batch_size=self.batch_size, shuffle=False)
+          batch_size=self.batch_size, shuffle=False, fake_data=True)
 
+      batch_x = np.array(batch_x)
+      batch_y = np.array(batch_y)
       batch_x = batch_x.reshape((self.batch_size, self.time_steps,
                                  self.n_input))
       sess.run(opt, feed_dict={x: batch_x, y: batch_y})
@@ -228,7 +233,8 @@ class BidirectionalSequenceRnnTest(test_util.TensorFlowTestCase):
       - Expected output.
 
     """
-    b1, _ = self.mnist.train.next_batch(batch_size=1)
+    b1, _ = self.mnist.train.next_batch(batch_size=1, fake_data=True)
+    b1 = np.array(b1, dtype=np.dtype("float32"))
     sample_input = np.reshape(b1, (1, self.time_steps, self.n_input))
 
     expected_output = sess.run(output_class, feed_dict={x: sample_input})
diff --git a/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py b/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py
index d937a111529..ba936a4e8cd 100644
--- a/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py
+++ b/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py
@@ -27,7 +27,9 @@ from tensorflow.python.platform import test
 
 
 # Number of steps to train model.
-TRAIN_STEPS = 1
+# Dial to 0 means no training at all, all the weights will be just using their
+# initial values. This can help make the test smaller.
+TRAIN_STEPS = 0
 
 CONFIG = tf.ConfigProto(device_count={"GPU": 0})
 
@@ -37,7 +39,8 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
   def setUp(self):
     tf.reset_default_graph()
     # Import MNIST dataset
-    self.mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
+    self.mnist = input_data.read_data_sets(
+        "/tmp/data/", fake_data=True, one_hot=True)
 
     # Define constants
     # Unrolled through 28 time steps
@@ -133,8 +136,10 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
     sess.run(init)
     for _ in range(TRAIN_STEPS):
       batch_x, batch_y = self.mnist.train.next_batch(
-          batch_size=self.batch_size, shuffle=False)
+          batch_size=self.batch_size, fake_data=True)
 
+      batch_x = np.array(batch_x)
+      batch_y = np.array(batch_y)
       batch_x = batch_x.reshape((self.batch_size, self.time_steps,
                                  self.n_input))
       sess.run(opt, feed_dict={x: batch_x, y: batch_y})
@@ -184,7 +189,8 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
       - Expected output.
 
     """
-    b1, _ = self.mnist.train.next_batch(batch_size=1)
+    b1, _ = self.mnist.train.next_batch(batch_size=1, fake_data=True)
+    b1 = np.array(b1, dtype=np.dtype("float32"))
     sample_input = np.reshape(b1, (1, self.time_steps, self.n_input))
 
     expected_output = sess.run(output_class, feed_dict={x: sample_input})
diff --git a/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_rnn_test.py b/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_rnn_test.py
index a3859e1ad40..49c3d5e7757 100644
--- a/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_rnn_test.py
+++ b/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_rnn_test.py
@@ -30,7 +30,9 @@ from tensorflow.python.platform import test
 FLAGS = flags.FLAGS
 
 # Number of steps to train model.
-TRAIN_STEPS = 1
+# Dial to 0 means no training at all, all the weights will be just using their
+# initial values. This can help make the test smaller.
+TRAIN_STEPS = 0
 
 CONFIG = tf.ConfigProto(device_count={"GPU": 0})
 
@@ -57,7 +59,8 @@ class UnidirectionalSequenceRnnTest(test_util.TensorFlowTestCase):
     super(UnidirectionalSequenceRnnTest, self).setUp()
     # Import MNIST dataset
     data_dir = tempfile.mkdtemp(dir=FLAGS.test_tmpdir)
-    self.mnist = input_data.read_data_sets(data_dir, one_hot=True)
+    self.mnist = input_data.read_data_sets(
+        data_dir, fake_data=True, one_hot=True)
 
   def buildRnnLayer(self):
     return tf.keras.layers.StackedRNNCells([
@@ -128,8 +131,10 @@ class UnidirectionalSequenceRnnTest(test_util.TensorFlowTestCase):
     sess.run(tf.global_variables_initializer())
     for _ in range(TRAIN_STEPS):
       batch_x, batch_y = self.mnist.train.next_batch(
-          batch_size=self.batch_size, shuffle=False)
+          batch_size=self.batch_size, fake_data=True)
 
+      batch_x = np.array(batch_x)
+      batch_y = np.array(batch_y)
       batch_x = batch_x.reshape((self.batch_size, self.time_steps,
                                  self.n_input))
       sess.run(opt, feed_dict={x: batch_x, y: batch_y})
@@ -179,7 +184,8 @@ class UnidirectionalSequenceRnnTest(test_util.TensorFlowTestCase):
       - Expected output.
 
     """
-    b1, _ = self.mnist.train.next_batch(batch_size=1)
+    b1, _ = self.mnist.train.next_batch(batch_size=1, fake_data=True)
+    b1 = np.array(b1, dtype=np.dtype("float32"))
     sample_input = np.reshape(b1, (1, self.time_steps, self.n_input))
 
     expected_output = sess.run(output_class, feed_dict={x: sample_input})

From feac6d2f33d26878eaf346b920a0cd680e9d69ed Mon Sep 17 00:00:00 2001
From: Yunlu Li 
Date: Mon, 2 Dec 2019 13:26:33 -0800
Subject: [PATCH 167/279] Better quantized op tests.

PiperOrigin-RevId: 283404748
Change-Id: Ief89d65cf9632783b3e602c8550bf4381d3d4705
---
 .../testing/generated_examples_zip_test.cc    | 25 ++++--
 tensorflow/lite/testing/tflite_driver.cc      | 85 +++++++++++++++++--
 tensorflow/lite/testing/tflite_driver.h       |  2 +
 tensorflow/lite/testing/tflite_driver_test.cc |  2 +-
 tensorflow/lite/testing/toco_convert.py       |  2 +
 5 files changed, 104 insertions(+), 12 deletions(-)

diff --git a/tensorflow/lite/testing/generated_examples_zip_test.cc b/tensorflow/lite/testing/generated_examples_zip_test.cc
index 16b4675bb0d..d1b3d267eba 100644
--- a/tensorflow/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/lite/testing/generated_examples_zip_test.cc
@@ -156,14 +156,22 @@ const std::map& GetKnownBrokenNnapiTests() {
 const std::map& GetKnownQuantizeBrokenTests() {
   static const std::map* const kQuantizeBrokenTests =
       new std::map({
-          {R"(^\/conv.*fully_quantize=True)", "134594898"},
-          {R"(^\/depthwiseconv.*fully_quantize=True)", "134594898"},
           {R"(^\/sum.*fully_quantize=True)", "134594898"},
           {R"(^\/l2norm.*fully_quantize=True)", "134594898"},
       });
   return *kQuantizeBrokenTests;
 }
 
+const std::map& GetQuantizeTestsError() {
+  static const std::map* const kQuantizeBrokenTests =
+      new std::map({
+          {R"(^\/conv_relu1.*fully_quantize=True)", 18},
+          {R"(^\/conv_relu6.*fully_quantize=True)", 8},
+          {R"(^\/maximum.*fully_quantize=True)", 8},
+      });
+  return *kQuantizeBrokenTests;
+}
+
 // Allows test data to be unarchived into a temporary directory and makes
 // sure those temporary directories are removed later.
 class ArchiveEnvironment : public ::testing::Environment {
@@ -299,10 +307,16 @@ TEST_P(OpsTest, RunZipTests) {
   tflite::testing::TfLiteDriver test_driver(
       FLAGS_use_nnapi ? TfLiteDriver::DelegateType::kNnapi
                       : TfLiteDriver::DelegateType::kNone);
+
+  auto quantized_tests_error = GetQuantizeTestsError();
   bool fully_quantize = false;
   if (test_path.find("fully_quantize=True") != std::string::npos) {
-    // TODO(b/134594898): Tighten this constraint.
-    test_driver.SetThreshold(0.2, 0.1);
+    for (const auto& p : quantized_tests_error) {
+      if (RE2::PartialMatch(test_name, p.first)) {
+        test_driver.SetQuantizationErrorMultiplier(p.second);
+        break;
+      }
+    }
     fully_quantize = true;
   }
 
@@ -313,7 +327,6 @@ TEST_P(OpsTest, RunZipTests) {
     auto kBrokenNnapiTests = GetKnownBrokenNnapiTests();
     broken_tests.insert(kBrokenNnapiTests.begin(), kBrokenNnapiTests.end());
   }
-  auto quantize_broken_tests = GetKnownQuantizeBrokenTests();
 
   bool result = tflite::testing::ParseAndRunTests(&tflite_stream, &test_driver);
   string message = test_driver.GetErrorMessage();
@@ -346,7 +359,7 @@ TEST_P(OpsTest, RunZipTests) {
     if (!result) {
       string bug_number;
       // See if the tests are potential quantize failures.
-      for (const auto& p : quantize_broken_tests) {
+      for (const auto& p : GetKnownQuantizeBrokenTests()) {
         if (RE2::PartialMatch(test_name, p.first)) {
           bug_number = p.second;
           break;
diff --git a/tensorflow/lite/testing/tflite_driver.cc b/tensorflow/lite/testing/tflite_driver.cc
index 795fb1fee99..47293016ab6 100644
--- a/tensorflow/lite/testing/tflite_driver.cc
+++ b/tensorflow/lite/testing/tflite_driver.cc
@@ -18,7 +18,6 @@ limitations under the License.
 #include 
 #include 
 #include 
-
 #include "absl/strings/escaping.h"
 #include "tensorflow/lite/builtin_op_data.h"
 #include "tensorflow/lite/delegates/flex/delegate.h"
@@ -37,6 +36,22 @@ namespace {
 const double kRelativeThreshold = 1e-2f;
 const double kAbsoluteThreshold = 1e-4f;
 
+// For quantized tests, we use a different error measurement from float ones.
+// Assumes the baseline is a always a float TF model.
+// Error of a quantized model compared to the baseline comes from two sources:
+//   1. the math done with quantized inputs, and
+//   2. quantization of the output.
+// Assumes there is no error introduced by source 1, the theoretical maximum
+// error allowed for the output is 0.5 * scale, because scale is equal to the
+// size of the quantization bucket.
+//
+// As a result, we use `scale` as a unit for measuring the quantization error.
+// To add the error introduced by source 1 as well, we need to relax the
+// multiplier from 0.5 to a larger number, which is model/op dependent.
+// The number below is good enough to account for both the two sources of error
+// for most quantized op tests to pass.
+const int kQuantizationErrorMultiplier = 4;
+
 // Returns the value in the given position in a tensor.
 template 
 T Value(void* data, int index) {
@@ -58,15 +73,31 @@ unique_void_ptr make_type_erased_array(size_t size) {
                          [](void* data) { delete[] static_cast(data); });
 }
 
+bool IsQuantized(const TfLiteTensor& tensor) {
+  if (tensor.type != kTfLiteInt8) return false;
+
+  if (tensor.quantization.params != nullptr) {
+    auto* quantization =
+        reinterpret_cast(tensor.quantization.params);
+    if (quantization->scale != nullptr && quantization->scale->size == 1 &&
+        quantization->zero_point != nullptr &&
+        quantization->zero_point->size == 1) {
+      return true;
+    }
+  }
+  return false;
+}
 }  // namespace
 
 class TfLiteDriver::DataExpectation {
  public:
-  DataExpectation(double relative_threshold, double absolute_threshold)
+  DataExpectation(double relative_threshold, double absolute_threshold,
+                  int quantization_error_multiplier)
       : data_(nullptr, nullptr),
         num_elements_(0),
         relative_threshold_(relative_threshold),
-        absolute_threshold_(absolute_threshold) {}
+        absolute_threshold_(absolute_threshold),
+        quantization_error_multiplier_(quantization_error_multiplier) {}
 
   template 
   void SetData(const string& csv_values) {
@@ -128,11 +159,13 @@ class TfLiteDriver::DataExpectation {
   }
 
   bool TypedCheckString(bool verbose, const TfLiteTensor& tensor);
+  bool QuantizedCheck(bool verbose, const TfLiteTensor& tensor);
 
   unique_void_ptr data_;
   size_t num_elements_;
   double relative_threshold_;
   double absolute_threshold_;
+  int quantization_error_multiplier_;
 };
 
 class TfLiteDriver::ShapeExpectation {
@@ -218,8 +251,37 @@ bool TfLiteDriver::DataExpectation::TypedCheckString(
   return true;
 }
 
+bool TfLiteDriver::DataExpectation::QuantizedCheck(bool verbose,
+                                                   const TfLiteTensor& tensor) {
+  auto* quantization =
+      reinterpret_cast(tensor.quantization.params);
+  const float scale = quantization->scale->data[0];
+  const int32 zero_point = quantization->zero_point->data[0];
+
+  bool good_result = true;
+  for (int i = 0; i < tensor.bytes; i++) {
+    const int32 computed = tensor.data.int8[i];
+    const float dequantized =
+        static_cast(scale * (computed - zero_point));
+    const float reference = Value(data_.get(), i);
+    if (std::abs(dequantized - reference) >
+        quantization_error_multiplier_ * scale) {
+      if (verbose) {
+        std::cerr << "  index " << i << ": got " << dequantized
+                  << ", but expected " << reference << std::endl;
+      }
+      good_result = false;
+    }
+  }
+  return good_result;
+}
+
 bool TfLiteDriver::DataExpectation::Check(bool verbose,
                                           const TfLiteTensor& tensor) {
+  if (IsQuantized(tensor)) {
+    return QuantizedCheck(verbose, tensor);
+  }
+
   switch (tensor.type) {
     case kTfLiteFloat32:
       return TypedCheck(verbose, tensor);
@@ -247,7 +309,8 @@ bool TfLiteDriver::DataExpectation::Check(bool verbose,
 TfLiteDriver::TfLiteDriver(DelegateType delegate_type, bool reference_kernel)
     : delegate_(nullptr, nullptr),
       relative_threshold_(kRelativeThreshold),
-      absolute_threshold_(kAbsoluteThreshold) {
+      absolute_threshold_(kAbsoluteThreshold),
+      quantization_error_multiplier_(kQuantizationErrorMultiplier) {
   if (reference_kernel) {
     resolver_.reset(new ops::builtin::BuiltinRefOpResolver);
   } else {
@@ -395,6 +458,11 @@ void TfLiteDriver::SetThreshold(double relative_threshold,
   absolute_threshold_ = absolute_threshold;
 }
 
+void TfLiteDriver::SetQuantizationErrorMultiplier(
+    int quantization_error_multiplier) {
+  quantization_error_multiplier_ = quantization_error_multiplier;
+}
+
 void TfLiteDriver::SetExpectation(int id, const string& csv_values) {
   if (!IsValid()) return;
   auto* tensor = interpreter_->tensor(id);
@@ -402,7 +470,14 @@ void TfLiteDriver::SetExpectation(int id, const string& csv_values) {
     Invalidate(absl::StrCat("Overridden expectation for tensor '", id, "'"));
   }
   expected_output_[id].reset(
-      new DataExpectation(relative_threshold_, absolute_threshold_));
+      new DataExpectation(relative_threshold_, absolute_threshold_,
+                          quantization_error_multiplier_));
+
+  if (IsQuantized(*tensor)) {
+    expected_output_[id]->SetData(csv_values);
+    return;
+  }
+
   switch (tensor->type) {
     case kTfLiteFloat32:
       expected_output_[id]->SetData(csv_values);
diff --git a/tensorflow/lite/testing/tflite_driver.h b/tensorflow/lite/testing/tflite_driver.h
index ae843d1cba7..258902606a5 100644
--- a/tensorflow/lite/testing/tflite_driver.h
+++ b/tensorflow/lite/testing/tflite_driver.h
@@ -64,6 +64,7 @@ class TfLiteDriver : public TestRunner {
   bool CheckResults() override;
   string ReadOutput(int id) override;
   void SetThreshold(double relative_threshold, double absolute_threshold);
+  void SetQuantizationErrorMultiplier(int quantization_error_multiplier);
 
  protected:
   Interpreter::TfLiteDelegatePtr delegate_;
@@ -95,6 +96,7 @@ class TfLiteDriver : public TestRunner {
   std::map tensors_to_deallocate_;
   double relative_threshold_;
   double absolute_threshold_;
+  int quantization_error_multiplier_;
 };
 
 }  // namespace testing
diff --git a/tensorflow/lite/testing/tflite_driver_test.cc b/tensorflow/lite/testing/tflite_driver_test.cc
index 99efd2d66d1..6dac9565dde 100644
--- a/tensorflow/lite/testing/tflite_driver_test.cc
+++ b/tensorflow/lite/testing/tflite_driver_test.cc
@@ -112,7 +112,7 @@ TEST(TfliteDriverTest, AddQuantizedInt8Test) {
 
   runner->SetInput(1, "1,1,1,1");
 
-  runner->SetExpectation(2, "3,3,3,3");
+  runner->SetExpectation(2, "0.0117,0.0117,0.0117,0.0117");
 
   runner->Invoke();
   ASSERT_TRUE(runner->IsValid());
diff --git a/tensorflow/lite/testing/toco_convert.py b/tensorflow/lite/testing/toco_convert.py
index 3e8a489c5f8..e8d1e8eec12 100644
--- a/tensorflow/lite/testing/toco_convert.py
+++ b/tensorflow/lite/testing/toco_convert.py
@@ -146,6 +146,8 @@ def toco_convert(options, graph_def, input_tensors, output_tensors, **kwargs):
       if extra_toco_options.inference_output_type:
         converter.inference_output_type = (
             extra_toco_options.inference_output_type)
+      else:
+        converter.inference_output_type = tf.int8
 
       try:
         tflite_model = converter.convert()

From fc7a817dbd4a931bad458fa5896c8ebdba15d60f Mon Sep 17 00:00:00 2001
From: Brian Zhao 
Date: Mon, 2 Dec 2019 13:27:38 -0800
Subject: [PATCH 168/279] Change the client_secret that looks like a "private
 key" to something more obviously fake.

PiperOrigin-RevId: 283404936
Change-Id: Ide7c9816e11c7f18debce9f3eca504e79c1ce47b
---
 tensorflow/core/platform/cloud/oauth_client_test.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tensorflow/core/platform/cloud/oauth_client_test.cc b/tensorflow/core/platform/cloud/oauth_client_test.cc
index 88e11a58d36..890e75a7036 100644
--- a/tensorflow/core/platform/cloud/oauth_client_test.cc
+++ b/tensorflow/core/platform/cloud/oauth_client_test.cc
@@ -64,7 +64,7 @@ TEST(OAuthClientTest, GetTokenFromRefreshTokenJson) {
   const string credentials_json = R"(
       {
         "client_id": "test_client_id",
-        "client_secret": "test_client_secret",
+        "client_secret": "@@@test_client_secret@@@",
         "refresh_token": "test_refresh_token",
         "type": "authorized_user"
       })";
@@ -75,7 +75,7 @@ TEST(OAuthClientTest, GetTokenFromRefreshTokenJson) {
   std::vector requests({new FakeHttpRequest(
       "Uri: https://www.googleapis.com/oauth2/v3/token\n"
       "Post body: client_id=test_client_id&"
-      "client_secret=test_client_secret&"
+      "client_secret=@@@test_client_secret@@@&"
       "refresh_token=test_refresh_token&grant_type=refresh_token\n",
       kTokenJson)});
   FakeEnv env;

From 1868de838d1e21a80155fa50138451241e9cdc61 Mon Sep 17 00:00:00 2001
From: Gunhan Gulsoy 
Date: Mon, 2 Dec 2019 13:31:04 -0800
Subject: [PATCH 169/279] Use bazel's platform specific config options to
 switch between compiler options.

Documentation:
https://docs.bazel.build/versions/master/command-line-reference.html#flag--enable_platform_specific_config
PiperOrigin-RevId: 283405622
Change-Id: I38709761dcc80c68a3ae3b301a3bc15ab5a43d51
---
 .bazelrc     | 70 +++++++++++++++++++++++++++++++++++-----------------
 configure.py | 14 -----------
 2 files changed, 47 insertions(+), 37 deletions(-)

diff --git a/.bazelrc b/.bazelrc
index 7219e9e23c2..5fd28e867c0 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -202,10 +202,6 @@ build --define=allow_oversize_protos=true
 build --spawn_strategy=standalone
 build -c opt
 
-# By default, build TF in C++ 14 mode.
-build --cxxopt=-std=c++14
-build --host_cxxopt=-std=c++14
-
 # Make Bazel print out all options from rc files.
 build --announce_rc
 
@@ -235,13 +231,55 @@ build:c++17 --cxxopt=-std=c++1z
 build:c++17 --cxxopt=-stdlib=libc++
 build:c++1z --config=c++17
 
-# Default paths for TF_SYSTEM_LIBS
-build --define=PREFIX=/usr
-build --define=LIBDIR=$(PREFIX)/lib
-build --define=INCLUDEDIR=$(PREFIX)/include
+# Enable using platform specific build settings
+build --enable_platform_specific_config
 
 # Suppress C++ compiler warnings, otherwise build logs become 10s of MBs.
-build --copt=-w
+build:linux --copt=-w
+build:macos --copt=-w
+build:windows --copt=/w
+
+# Default paths for TF_SYSTEM_LIBS
+build:linux --define=PREFIX=/usr
+build:linux --define=LIBDIR=$(PREFIX)/lib
+build:linux --define=INCLUDEDIR=$(PREFIX)/include
+build:macos --define=PREFIX=/usr
+build:macos --define=LIBDIR=$(PREFIX)/lib
+build:macos --define=INCLUDEDIR=$(PREFIX)/include
+# TF_SYSTEM_LIBS do not work on windows.
+
+# By default, build TF in C++ 14 mode.
+build:linux --cxxopt=-std=c++14
+build:linux --host_cxxopt=-std=c++14
+build:macos --cxxopt=-std=c++14
+build:macos --host_cxxopt=-std=c++14
+build:windows --cxxopt=/std:c++14
+build:windows --host_cxxopt=/std:c++14
+
+# On windows, we still link everything into a single DLL.
+build:windows --config=monolithic
+
+# Make sure to include as little of windows.h as possible
+build:windows --copt=-DWIN32_LEAN_AND_MEAN
+build:windows --host_copt=-DWIN32_LEAN_AND_MEAN
+build:windows --copt=-DNOGDI
+build:windows --host_copt=-DNOGDI
+
+# Misc build options we need for windows.
+build:windows --linkopt=/DEBUG
+build:windows --host_linkopt=/DEBUG
+build:windows --linkopt=/OPT:REF
+build:windows --host_linkopt=/OPT:REF
+build:windows --linkopt=/OPT:ICF
+build:windows --host_linkopt=/OPT:ICF
+build:windows --experimental_strict_action_env=true
+build:windows --incompatible_windows_native_test_wrapper
+
+# Verbose failure logs when something goes wrong
+build:windows --verbose_failures
+
+# On windows, we never cross compile
+build:windows --distinct_host_configuration=false
 
 # Suppress all warning messages.
 build:short_logs --output_filter=DONT_MATCH_ANYTHING
@@ -346,20 +384,6 @@ build:rbe_win --javabase="@org_tensorflow//third_party/toolchains/preconfig/win_
 build:rbe_win --platforms="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803"
 build:rbe_win --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe
 
-# Misc build options we need for windows
-build:rbe_win --copt=-DWIN32_LEAN_AND_MEAN
-build:rbe_win --host_copt=-DWIN32_LEAN_AND_MEAN
-build:rbe_win --copt=-DNOGDI
-build:rbe_win --host_copt=-DNOGDI
-build:rbe_win --linkopt=/DEBUG
-build:rbe_win --host_linkopt=/DEBUG
-build:rbe_win --linkopt=/OPT:REF
-build:rbe_win --host_linkopt=/OPT:REF
-build:rbe_win --linkopt=/OPT:ICF
-build:rbe_win --host_linkopt=/OPT:ICF
-build:rbe_win --config=monolithic
-build:rbe_win --experimental_strict_action_env=true
-build:rbe_win --incompatible_windows_native_test_wrapper
 # TODO(gunan): Remove once we use MSVC 2019 with latest patches.
 build:rbe_win --define=override_eigen_strong_inline=true
 
diff --git a/configure.py b/configure.py
index 2c7914052e9..fedbd470f2d 100644
--- a/configure.py
+++ b/configure.py
@@ -1232,20 +1232,6 @@ def is_reduced_optimize_huge_functions_available(environ_cp):
 
 def set_windows_build_flags(environ_cp):
   """Set Windows specific build options."""
-  # The non-monolithic build is not supported yet
-  write_to_bazelrc('build --config monolithic')
-  # Suppress warning messages
-  write_to_bazelrc('build --copt=-w --host_copt=-w')
-  # Fix winsock2.h conflicts
-  write_to_bazelrc(
-      'build --copt=-DWIN32_LEAN_AND_MEAN --host_copt=-DWIN32_LEAN_AND_MEAN '
-      '--copt=-DNOGDI --host_copt=-DNOGDI')
-  # Output more verbose information when something goes wrong
-  write_to_bazelrc('build --verbose_failures')
-  # The host and target platforms are the same in Windows build. So we don't
-  # have to distinct them. This avoids building the same targets twice.
-  write_to_bazelrc('build --distinct_host_configuration=false')
-
   if is_reduced_optimize_huge_functions_available(environ_cp):
     write_to_bazelrc(
         'build --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions'

From 7f347e6aa489e9e44f78feb4592f50d5d6bf94dc Mon Sep 17 00:00:00 2001
From: Saurabh Saxena 
Date: Mon, 2 Dec 2019 13:49:20 -0800
Subject: [PATCH 170/279] Correctly handle Nones in
 cond_v2._make_indexed_slices_indices_types_match.

PiperOrigin-RevId: 283409753
Change-Id: Ic83f8006693a404cbe4c723c0a103ffd98a76d1e
---
 .../python/kernel_tests/cond_v2_test.py       | 23 +++++++++++++++++++
 tensorflow/python/ops/cond_v2.py              | 13 +++++++----
 2 files changed, 31 insertions(+), 5 deletions(-)

diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py
index d5cb1c49555..60d34a0e299 100644
--- a/tensorflow/python/kernel_tests/cond_v2_test.py
+++ b/tensorflow/python/kernel_tests/cond_v2_test.py
@@ -119,6 +119,29 @@ class CondV2Test(test.TestCase):
     output = build_cond_with_indexed_slices()
     self.assertAllEqual(output, [1.])
 
+  def testReturnsNonesAndIndexedSlices(self):
+
+    @def_function.function
+    def build_cond_with_indexed_slices():
+      pred = constant_op.constant(True)
+
+      def true_fn():
+        return (None, None, None,
+                math_ops._as_indexed_slices(constant_op.constant([1.])))
+
+      def false_fn():
+        return (None, None, None,
+                math_ops._as_indexed_slices(constant_op.constant([2.])))
+
+      result = cond_v2.cond_v2(pred, true_fn, false_fn)
+      self.assertIsNone(result[0])
+      self.assertIsNone(result[1])
+      self.assertIsNone(result[2])
+      return ops.convert_to_tensor(result[3])
+
+    output = build_cond_with_indexed_slices()
+    self.assertAllEqual(output, [1.])
+
   def testExternalControlDependencies(self):
     with ops.Graph().as_default(), self.test_session():
       v = variables.Variable(1.0)
diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py
index c0237d1bf9f..fd0102328d1 100644
--- a/tensorflow/python/ops/cond_v2.py
+++ b/tensorflow/python/ops/cond_v2.py
@@ -617,15 +617,18 @@ def _make_output_composite_tensors_match(op_type, branch_graphs):
 def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
   """Match dtype of IndexedSlices.indices in outputs of branch_graphs."""
   assert branch_graphs
+  # Indices of `IndexedSlices.indices` tensors in `branch_graphs[i].outputs`.
   indexed_slice_indices = []
   current_index = 0
+  # Note that this still contains Nones. We leave those in so that error
+  # messages contain the correct indices. We handle the Nones later when
+  # updating `current_index`.
   branch_outputs_flat_with_composites = [
       nest.flatten(branch_graph.structured_outputs, expand_composites=False)
       for branch_graph in branch_graphs
   ]
   outs_per_branch = [len(outs) for outs in branch_outputs_flat_with_composites]
   assert len(set(outs_per_branch)) == 1, outs_per_branch
-  num_none_outputs = 0
   # Store indices of IndexedSlices.indices in `indexed_slice_indices`.
   for output_idx, branch_outs in enumerate(
       zip(*branch_outputs_flat_with_composites)):
@@ -640,17 +643,17 @@ def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
       indexed_slice_indices.append(current_index + 1)
     if nest.is_sequence_or_composite(branch_outs[0]):
       current_index += len(nest.flatten(branch_outs[0], expand_composites=True))
-    else:
+    elif branch_outs[0] is not None:
+      # `FuncGraph.outputs` does not contain Nones so no need to update the
+      # counter in that case.
       current_index += 1
-    if branch_outs[0] is None:
-      num_none_outputs += 1
 
   if not indexed_slice_indices:
     return
 
   # `FuncGraph.outputs` is the flattened `FuncGraph.structured_outputs` minus
   # the Nones.
-  if current_index != len(branch_graphs[0].outputs) + num_none_outputs:
+  if current_index != len(branch_graphs[0].outputs):
     raise ValueError("Insufficient elements in branch_graphs[0].outputs.\n"
                      "Expected: %i\n"
                      "Actual: %i" %

From 4eb24e50a79ea40a6ef2a6ed2497638fe51ed094 Mon Sep 17 00:00:00 2001
From: Berkin Ilbeyi 
Date: Mon, 2 Dec 2019 13:52:40 -0800
Subject: [PATCH 171/279] [XLA] Fix a scheduling bug with evictions to default
 mem.

When simplifying the graph for dead code, we were previously removing the
deleted instruction from the schedule. However, the scheduler, which is run
after SimplifyGraph, relies on the original logical time (index into the
instruction schedule). So, when some instructions have been deleted, we end up
scheduling certain operation later than intended. Most seriously, the evictions
could have been scheduled later than they were supposed to, corrupting the
memory since we might have reused the evicted memory. The solution is to mark
the deleted instructions with a nullptr in the schedule instead of actually
deleting them.

PiperOrigin-RevId: 283410438
Change-Id: Ia671ede14469dd00ac218cfb8d714e171152bb37
---
 .../xla/service/memory_space_assignment.cc    | 28 ++++--
 .../xla/service/memory_space_assignment.h     |  5 +-
 .../service/memory_space_assignment_test.cc   | 93 +++++++++++++++++++
 3 files changed, 114 insertions(+), 12 deletions(-)

diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc
index 751d258142a..28c93fb75fd 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc
@@ -1130,7 +1130,15 @@ Status MemorySpaceAssignment::SimplifyGraph() {
           // Ensure the exported preset assignments don't contain a refence to
           // the removed instruction.
           preset_assignments_->RemoveAssignmentForInstruction(instruction);
-          flattened_instruction_sequence_.remove_instruction(instruction);
+          // Instead of deleting the instruction from the schedule, replace it
+          // with a nullptr. This is needed because FixSchedule relies on the
+          // logical time that is the index into flattened_instructions_ for
+          // scheduling asynchronous copies.
+          auto instruction_it =
+              absl::c_find(flattened_instructions_, instruction);
+          if (instruction_it != flattened_instructions_.end()) {
+            *instruction_it = nullptr;
+          }
           TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
           computation_modified = true;
         } else if (instruction->opcode() == HloOpcode::kGetTupleElement) {
@@ -1228,12 +1236,12 @@ void MemorySpaceAssignment::ScheduleAsynchronousCopies() {
 
       // If the copy start doesn't happen to be scheduled at the correct
       // computation, delay it until the correct computation starts.
-      const auto& flattened_instructions =
-          flattened_instruction_sequence_.instructions();
       int64 copy_start_schedule_after =
           copy_allocation->copy_start_schedule_after();
+      // Accessing flattened_instructions_ here without checking if it is
+      // nullptr is safe because this method is called before SimplifyGraph.
       while (copy_allocation->instruction()->parent() !=
-             flattened_instructions[copy_start_schedule_after]->parent()) {
+             flattened_instructions_[copy_start_schedule_after]->parent()) {
         VLOG(4) << "Delaying CopyStart (" << copy_start_schedule_after << " to "
                 << (copy_start_schedule_after + 1) << ") for "
                 << copy_allocation->copy_start()->ToString()
@@ -1264,8 +1272,7 @@ Status MemorySpaceAssignment::FixSchedule() {
     VLOG(4) << "Scheduling: " << computation->ToString();
 
     for (int64 instruction_index = 0;
-         instruction_index <
-         flattened_instruction_sequence_.instructions().size();
+         instruction_index < flattened_instructions_.size();
          ++instruction_index) {
       auto insts_before_iter = schedule_before_.find(instruction_index);
       if (insts_before_iter != schedule_before_.end()) {
@@ -1276,10 +1283,11 @@ Status MemorySpaceAssignment::FixSchedule() {
           }
         }
       }
-      HloInstruction* instruction =
-          flattened_instruction_sequence_.instructions()[instruction_index];
-      // Insert only if not previously inserted.
-      if (!inserted_instructions.contains(instruction) &&
+      HloInstruction* instruction = flattened_instructions_[instruction_index];
+      // Insert only if it is not deleted (SimplifyGraph sets it to nullptr if
+      // it was deleted) and not previously inserted.
+      if (instruction != nullptr &&
+          !inserted_instructions.contains(instruction) &&
           instruction->parent() == computation) {
         EnsureInstructionAndOperandsInserted(instruction, &new_sequence,
                                              &inserted_instructions);
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h
index bfc91664bea..a8b3310cf24 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.h
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.h
@@ -450,7 +450,8 @@ class MemorySpaceAssignment {
       absl::Span flattened_instructions)
       : module_(module),
         alternate_memory_space_(alternate_memory_space),
-        flattened_instruction_sequence_(flattened_instructions),
+        flattened_instructions_(flattened_instructions.begin(),
+                                flattened_instructions.end()),
         preset_assignments_(absl::make_unique()) {}
 
   // Process calls Process methods of the allocations after the allocations have
@@ -479,7 +480,7 @@ class MemorySpaceAssignment {
 
   HloModule* module_;
   int64 alternate_memory_space_;
-  HloInstructionSequence flattened_instruction_sequence_;
+  std::vector flattened_instructions_;
   AllocationMap allocation_map_;
   std::unique_ptr preset_assignments_;
 
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
index 6041b96636e..7e3ce7dfbbd 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
@@ -2032,6 +2032,99 @@ TEST_P(MemorySpaceAssignmentTest, SimpleWhileTupleTest) {
   EXPECT_THAT(while0, op::ShapeWithLayout(t_s32_f32v1_in_default_mem));
 }
 
+TEST_P(MemorySpaceAssignmentTest, EvictionsShouldntBeDelayed) {
+  // This test reproduces an eviction scheduling bug where evictions to default
+  // memory can happen later than intended, causing memory corruption. This test
+  // is a variant of MemoryBoundednessBufferIntervalCompare but uses f32[4,3]
+  // tensors instead, so at most two tensors should fit in the alternate memory
+  // space at a given time. We have a number of redundant operations
+  // (tanh_redundant ops) that do not have users. The bug was due to
+  // SimplifyGraph removing dead instructions, and removing them from the
+  // schedule. However, the CopyStart/CopyDone insertion relies on the schedule
+  // indexes, so they could be inserted too late.
+  HloComputation::Builder builder(TestName());
+  Shape shape = ShapeUtil::MakeShape(F32, {4, 3});
+  HloInstruction* p0 =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
+  HloInstruction* tanh0 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
+  HloInstruction* tanh_redundant0 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
+  HloInstruction* tanh_redundant1 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
+  HloInstruction* tanh_redundant2 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
+  HloInstruction* tanh_redundant3 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
+  HloInstruction* tanh_redundant4 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
+  HloInstruction* tanh_redundant5 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
+  HloInstruction* tanh_redundant6 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
+  HloInstruction* negate0 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, tanh0));
+  HloInstruction* tanh1 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, negate0));
+  HloInstruction* negate1 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
+  HloInstruction* tanh2 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh1));
+  HloInstruction* negate2 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
+  HloInstruction* tanh3 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh2));
+  HloInstruction* negate3 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
+  HloInstruction* tuple = builder.AddInstruction(
+      HloInstruction::CreateTuple({tanh3, negate3, tanh0}));
+
+  auto module = CreateNewVerifiedModule();
+  HloComputation* computation = module->AddEntryComputation(builder.Build());
+
+  HloSchedule schedule(module.get());
+  schedule.set_sequence(
+      computation,
+      {p0, tanh0, tanh_redundant0, tanh_redundant1, tanh_redundant2,
+       tanh_redundant3, tanh_redundant4, tanh_redundant5, tanh_redundant6,
+       negate0, tanh1, negate1, tanh2, negate2, tanh3, negate3, tuple});
+  TF_CHECK_OK(module->set_schedule(schedule));
+
+  AssignMemorySpaceUsingCostAnalysis(module.get());
+
+  TF_ASSERT_OK_AND_ASSIGN(auto alias_analysis,
+                          HloAliasAnalysis::Run(module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(auto hlo_live_range,
+                          HloLiveRange::Run(module->schedule(), *alias_analysis,
+                                            module->entry_computation()));
+
+  std::vector num_live_buffers_in_alternate_mem(
+      hlo_live_range->flattened_instruction_sequence().size() + 1, 0);
+
+  // Go through each value and for those that are allocated in the alternate
+  // memory space, increment (inclusive) num_live_buffers_in_alternate_mem for
+  // every time step that they are live.
+  for (const HloValue* value : alias_analysis->dataflow_analysis().values()) {
+    const Shape& shape = value->shape();
+    if (!shape.has_layout() ||
+        shape.layout().memory_space() == kDefaultMemorySpace) {
+      continue;
+    }
+
+    HloLiveRange::TimeBound time_bound =
+        hlo_live_range->buffer_live_ranges().at(value);
+    for (int i = time_bound.start; i <= time_bound.end; ++i) {
+      ++num_live_buffers_in_alternate_mem[i];
+    }
+  }
+
+  // The test memory can at most hold two f32[4,3] buffers at a time. If there
+  // is more than that, it means we have memory corruption.
+  for (int i = 0; i < num_live_buffers_in_alternate_mem.size(); ++i) {
+    EXPECT_LE(num_live_buffers_in_alternate_mem[i], 2);
+  }
+}
+
 INSTANTIATE_TEST_SUITE_P(MemorySpaceAssignmentInstantiation,
                          MemorySpaceAssignmentTest,
                          ::testing::Values(false, true));

From 2baf25e56383a7ca7c15390b0b1c3ebfabca1407 Mon Sep 17 00:00:00 2001
From: Dan Moldovan 
Date: Mon, 2 Dec 2019 14:02:10 -0800
Subject: [PATCH 172/279] Cleanup: use a more consistent naming scheme for the
 parsing/loading utilities.

PiperOrigin-RevId: 283412576
Change-Id: I3fd801cc2044e23aec0d57da7a17efdd5b3e3a1c
---
 .../autograph/converters/arg_defaults_test.py |  5 +-
 tensorflow/python/autograph/core/converter.py |  9 +--
 .../python/autograph/core/converter_test.py   |  4 +-
 .../autograph/core/converter_testing.py       |  6 +-
 tensorflow/python/autograph/impl/api_test.py  |  2 +-
 .../python/autograph/impl/conversion.py       |  8 +-
 .../python/autograph/impl/conversion_test.py  |  7 +-
 tensorflow/python/autograph/pyct/BUILD        |  6 +-
 .../python/autograph/pyct/ast_util_test.py    | 47 ++++++------
 tensorflow/python/autograph/pyct/cfg.py       |  7 +-
 .../pyct/common_transformers/anf_test.py      | 12 +--
 .../autograph/pyct/{compiler.py => loader.py} | 73 +++----------------
 .../pyct/{compiler_test.py => loader_test.py} | 42 +++--------
 .../python/autograph/pyct/origin_info.py      |  2 +-
 .../python/autograph/pyct/origin_info_test.py |  4 +-
 tensorflow/python/autograph/pyct/parser.py    | 64 ++++++++++++++--
 .../python/autograph/pyct/parser_test.py      | 28 +++++++
 .../python/autograph/pyct/qual_names_test.py  |  6 +-
 tensorflow/python/autograph/pyct/templates.py |  2 +-
 .../python/autograph/pyct/templates_test.py   | 18 ++---
 .../python/autograph/pyct/transformer.py      |  6 +-
 21 files changed, 177 insertions(+), 181 deletions(-)
 rename tensorflow/python/autograph/pyct/{compiler.py => loader.py} (55%)
 rename tensorflow/python/autograph/pyct/{compiler_test.py => loader_test.py} (71%)

diff --git a/tensorflow/python/autograph/converters/arg_defaults_test.py b/tensorflow/python/autograph/converters/arg_defaults_test.py
index 33dabe52839..6448f3124db 100644
--- a/tensorflow/python/autograph/converters/arg_defaults_test.py
+++ b/tensorflow/python/autograph/converters/arg_defaults_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
 
 from tensorflow.python.autograph.converters import arg_defaults
 from tensorflow.python.autograph.core import converter_testing
-from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import parser
 from tensorflow.python.platform import test
 
 
@@ -28,8 +28,7 @@ class ArgDefaultsTransformerTest(converter_testing.TestCase):
 
   def assertTransformedFirstLineIs(self, node, expected):
     self.assertEqual(
-        compiler.ast_to_source(node,
-                               include_encoding_marker=False).split('\n')[0],
+        parser.unparse(node, include_encoding_marker=False).split('\n')[0],
         expected)
 
   def test_no_args(self):
diff --git a/tensorflow/python/autograph/core/converter.py b/tensorflow/python/autograph/core/converter.py
index e9bf009d029..3102377d638 100644
--- a/tensorflow/python/autograph/core/converter.py
+++ b/tensorflow/python/autograph/core/converter.py
@@ -69,7 +69,6 @@ import enum
 from tensorflow.python.autograph.pyct import anno
 from tensorflow.python.autograph.pyct import ast_util
 from tensorflow.python.autograph.pyct import cfg
-from tensorflow.python.autograph.pyct import compiler
 from tensorflow.python.autograph.pyct import parser
 from tensorflow.python.autograph.pyct import qual_names
 from tensorflow.python.autograph.pyct import templates
@@ -329,10 +328,10 @@ class Base(transformer.Base):
     for other_value in arg_values_found[1:]:
       if not ast_util.matches(first_value, other_value):
         qn = anno.getanno(node, anno.Basic.QN)
-        raise ValueError('%s has ambiguous annotations for %s(%s): %s, %s' %
-                         (qn, directive.__name__, arg,
-                          compiler.ast_to_source(other_value).strip(),
-                          compiler.ast_to_source(first_value).strip()))
+        raise ValueError(
+            '%s has ambiguous annotations for %s(%s): %s, %s' %
+            (qn, directive.__name__, arg, parser.unparse(other_value).strip(),
+             parser.unparse(first_value).strip()))
     return first_value
 
   def visit(self, node):
diff --git a/tensorflow/python/autograph/core/converter_test.py b/tensorflow/python/autograph/core/converter_test.py
index 2d5b33465e0..030ec761d95 100644
--- a/tensorflow/python/autograph/core/converter_test.py
+++ b/tensorflow/python/autograph/core/converter_test.py
@@ -21,7 +21,7 @@ from __future__ import print_function
 from tensorflow.python.autograph.core import converter
 from tensorflow.python.autograph.core import converter_testing
 from tensorflow.python.autograph.pyct import anno
-from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import loader
 from tensorflow.python.autograph.pyct import parser
 from tensorflow.python.autograph.pyct import templates
 from tensorflow.python.platform import test
@@ -43,7 +43,7 @@ class ConversionOptionsTest(converter_testing.TestCase):
     '''
     opts_packed = templates.replace(template, opts_ast=opts_ast)
 
-    reparsed, _, _ = compiler.ast_to_object(opts_packed)
+    reparsed, _, _ = loader.load_ast(opts_packed)
     reparsed.__dict__['ag__'] = self.make_fake_mod(
         'fake_ag', converter.ConversionOptions, converter.Feature)
 
diff --git a/tensorflow/python/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py
index b11f210a951..4ea1187f8ed 100644
--- a/tensorflow/python/autograph/core/converter_testing.py
+++ b/tensorflow/python/autograph/core/converter_testing.py
@@ -32,7 +32,7 @@ from tensorflow.python.autograph.core import converter
 from tensorflow.python.autograph.core import function_wrappers
 from tensorflow.python.autograph.core import naming
 from tensorflow.python.autograph.lang import special_functions
-from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import loader
 from tensorflow.python.autograph.pyct import origin_info
 from tensorflow.python.autograph.pyct import parser
 from tensorflow.python.autograph.pyct import pretty_printer
@@ -97,7 +97,7 @@ class TestCase(test.TestCase):
       return f(*args, **kwargs)
 
     try:
-      result, source, source_map = compiler.ast_to_object(
+      result, source, source_map = loader.load_ast(
           node, include_source_map=True)
       # TODO(mdan): Move the unparsing from converter into pyct and reuse here.
 
@@ -120,7 +120,7 @@ class TestCase(test.TestCase):
       if source is None:
         print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False))
       else:
-        print('Offending compiled code:\n%s' % source)
+        print('Offending source code:\n%s' % source)
       raise
 
   @contextlib.contextmanager
diff --git a/tensorflow/python/autograph/impl/api_test.py b/tensorflow/python/autograph/impl/api_test.py
index 0188af10e2e..e9b9fc75150 100644
--- a/tensorflow/python/autograph/impl/api_test.py
+++ b/tensorflow/python/autograph/impl/api_test.py
@@ -1012,7 +1012,7 @@ class ApiTest(test.TestCase):
       return x
 
     # Just check that the output is parseable Python code.
-    self.assertIsNotNone(parser.parse_str(api.to_code(test_fn)))
+    self.assertIsNotNone(parser.parse(api.to_code(test_fn)))
 
   def test_to_code_with_wrapped_function(self):
 
diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py
index 4c8555eb293..c256b4e8e65 100644
--- a/tensorflow/python/autograph/impl/conversion.py
+++ b/tensorflow/python/autograph/impl/conversion.py
@@ -52,7 +52,7 @@ from tensorflow.python.autograph.core import naming
 from tensorflow.python.autograph.core import unsupported_features_checker
 from tensorflow.python.autograph.lang import special_functions
 from tensorflow.python.autograph.pyct import ast_util
-from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import loader
 from tensorflow.python.autograph.pyct import inspect_utils
 from tensorflow.python.autograph.pyct import origin_info
 from tensorflow.python.autograph.pyct import parser
@@ -282,8 +282,7 @@ def _convert_with_cache(entity, program_ctx, free_nonglobal_var_names):
                                        free_nonglobal_var_names,
                                        entity_info.future_features)
 
-    module, _, source_map = compiler.ast_to_object(
-        nodes, include_source_map=True)
+    module, _, source_map = loader.load_ast(nodes, include_source_map=True)
     module_name = module.__name__
 
     converted_entity_info = _ConvertedEntityFactoryInfo(
@@ -519,8 +518,7 @@ def convert_entity_to_ast(o, program_ctx):
         'supported for now.' % (o, type(o)))
 
   if logging.has_verbosity(2):
-    logging.log(2, 'Compiled output of %s:\n\n%s\n', o,
-                compiler.ast_to_source(nodes))
+    logging.log(2, 'Compiled output of %s:\n\n%s\n', o, parser.unparse(nodes))
   if logging.has_verbosity(4):
     for n in nodes:
       logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
diff --git a/tensorflow/python/autograph/impl/conversion_test.py b/tensorflow/python/autograph/impl/conversion_test.py
index 4cdffe6d6e2..a6336ef0dab 100644
--- a/tensorflow/python/autograph/impl/conversion_test.py
+++ b/tensorflow/python/autograph/impl/conversion_test.py
@@ -30,7 +30,7 @@ from tensorflow.python.autograph.core import config
 from tensorflow.python.autograph.core import converter
 from tensorflow.python.autograph.impl import api
 from tensorflow.python.autograph.impl import conversion
-from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import parser
 from tensorflow.python.framework import constant_op
 from tensorflow.python.keras.engine import training
 from tensorflow.python.platform import test
@@ -128,9 +128,8 @@ class ConversionTest(test.TestCase):
     self.assertIsInstance(fn_node, gast.FunctionDef)
     self.assertEqual('tf__f', name)
     self.assertEqual(
-        compiler.ast_to_source(
-            fn_node.args.defaults[0], include_encoding_marker=False).strip(),
-        'None')
+        parser.unparse(fn_node.args.defaults[0],
+                       include_encoding_marker=False).strip(), 'None')
 
   def test_convert_entity_to_ast_call_tree(self):
 
diff --git a/tensorflow/python/autograph/pyct/BUILD b/tensorflow/python/autograph/pyct/BUILD
index f7ae813c41d..b9931236428 100644
--- a/tensorflow/python/autograph/pyct/BUILD
+++ b/tensorflow/python/autograph/pyct/BUILD
@@ -25,10 +25,10 @@ py_library(
         "anno.py",
         "ast_util.py",
         "cfg.py",
-        "compiler.py",
         "error_utils.py",
         "errors.py",
         "inspect_utils.py",
+        "loader.py",
         "origin_info.py",
         "parser.py",
         "pretty_printer.py",
@@ -83,8 +83,8 @@ py_test(
 )
 
 py_test(
-    name = "compiler_test",
-    srcs = ["compiler_test.py"],
+    name = "loader_test",
+    srcs = ["loader_test.py"],
     python_version = "PY3",
     srcs_version = "PY2AND3",
     deps = [
diff --git a/tensorflow/python/autograph/pyct/ast_util_test.py b/tensorflow/python/autograph/pyct/ast_util_test.py
index bc7c3f93ac5..7ed0f7b6b85 100644
--- a/tensorflow/python/autograph/pyct/ast_util_test.py
+++ b/tensorflow/python/autograph/pyct/ast_util_test.py
@@ -26,7 +26,7 @@ import gast
 
 from tensorflow.python.autograph.pyct import anno
 from tensorflow.python.autograph.pyct import ast_util
-from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import loader
 from tensorflow.python.autograph.pyct import parser
 from tensorflow.python.autograph.pyct import qual_names
 from tensorflow.python.platform import test
@@ -39,28 +39,28 @@ class AstUtilTest(test.TestCase):
     self._invocation_counts = collections.defaultdict(lambda: 0)
 
   def test_rename_symbols_basic(self):
-    node = parser.parse_str('a + b')
+    node = parser.parse('a + b')
     node = qual_names.resolve(node)
 
     node = ast_util.rename_symbols(
         node, {qual_names.QN('a'): qual_names.QN('renamed_a')})
 
     self.assertIsInstance(node.value.left.id, str)
-    source = compiler.ast_to_source(node, include_encoding_marker=False)
+    source = parser.unparse(node, include_encoding_marker=False)
     self.assertEqual(source.strip(), 'renamed_a + b')
 
   def test_rename_symbols_attributes(self):
-    node = parser.parse_str('b.c = b.c.d')
+    node = parser.parse('b.c = b.c.d')
     node = qual_names.resolve(node)
 
     node = ast_util.rename_symbols(
         node, {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')})
 
-    source = compiler.ast_to_source(node, include_encoding_marker=False)
+    source = parser.unparse(node, include_encoding_marker=False)
     self.assertEqual(source.strip(), 'renamed_b_c = renamed_b_c.d')
 
   def test_rename_symbols_annotations(self):
-    node = parser.parse_str('a[i]')
+    node = parser.parse('a[i]')
     node = qual_names.resolve(node)
     anno.setanno(node, 'foo', 'bar')
     orig_anno = anno.getanno(node, 'foo')
@@ -71,7 +71,7 @@ class AstUtilTest(test.TestCase):
     self.assertIs(anno.getanno(node, 'foo'), orig_anno)
 
   def test_copy_clean(self):
-    node = parser.parse_str(
+    node = parser.parse(
         textwrap.dedent("""
       def f(a):
         return a + 1
@@ -82,7 +82,7 @@ class AstUtilTest(test.TestCase):
     self.assertFalse(hasattr(new_node, '__foo'))
 
   def test_copy_clean_preserves_annotations(self):
-    node = parser.parse_str(
+    node = parser.parse(
         textwrap.dedent("""
       def f(a):
         return a + 1
@@ -98,9 +98,9 @@ class AstUtilTest(test.TestCase):
     d = ast_util.keywords_to_dict(keywords)
     # Make sure we generate a usable dict node by attaching it to a variable and
     # compiling everything.
-    node = parser.parse_str('def f(b): pass')
+    node = parser.parse('def f(b): pass')
     node.body.append(ast.Return(d))
-    result, _, _ = compiler.ast_to_object(node)
+    result, _, _ = loader.load_ast(node)
     self.assertDictEqual(result.f(3), {'a': 3, 'c': 1, 'd': 'e'})
 
   def assertMatch(self, target_str, pattern_str):
@@ -131,12 +131,12 @@ class AstUtilTest(test.TestCase):
                        'super(Bar, _).__init__(_)')
 
   def _mock_apply_fn(self, target, source):
-    target = compiler.ast_to_source(target, include_encoding_marker=False)
-    source = compiler.ast_to_source(source, include_encoding_marker=False)
+    target = parser.unparse(target, include_encoding_marker=False)
+    source = parser.unparse(source, include_encoding_marker=False)
     self._invocation_counts[(target.strip(), source.strip())] += 1
 
   def test_apply_to_single_assignments_dynamic_unpack(self):
-    node = parser.parse_str('a, b, c = d')
+    node = parser.parse('a, b, c = d')
     ast_util.apply_to_single_assignments(node.targets, node.value,
                                          self._mock_apply_fn)
     self.assertDictEqual(self._invocation_counts, {
@@ -146,7 +146,7 @@ class AstUtilTest(test.TestCase):
     })
 
   def test_apply_to_single_assignments_static_unpack(self):
-    node = parser.parse_str('a, b, c = d, e, f')
+    node = parser.parse('a, b, c = d, e, f')
     ast_util.apply_to_single_assignments(node.targets, node.value,
                                          self._mock_apply_fn)
     self.assertDictEqual(self._invocation_counts, {
@@ -160,7 +160,7 @@ class AstUtilTest(test.TestCase):
       def f(a):
         return a + 1
     """
-    node = parser.parse_str(textwrap.dedent(src))
+    node = parser.parse(textwrap.dedent(src))
     for child_a, child_b in ast_util.parallel_walk(node, node):
       self.assertEqual(child_a, child_b)
 
@@ -169,22 +169,22 @@ class AstUtilTest(test.TestCase):
       def f(a):
         global g
     """
-    node = parser.parse_str(textwrap.dedent(src))
+    node = parser.parse(textwrap.dedent(src))
     for child_a, child_b in ast_util.parallel_walk(node, node):
       self.assertEqual(child_a, child_b)
 
   def test_parallel_walk_inconsistent_trees(self):
-    node_1 = parser.parse_str(
+    node_1 = parser.parse(
         textwrap.dedent("""
       def f(a):
         return a + 1
     """))
-    node_2 = parser.parse_str(
+    node_2 = parser.parse(
         textwrap.dedent("""
       def f(a):
         return a + (a * 2)
     """))
-    node_3 = parser.parse_str(
+    node_3 = parser.parse(
         textwrap.dedent("""
       def f(a):
         return a + 2
@@ -204,12 +204,11 @@ class AstUtilTest(test.TestCase):
     for node in matching_nodes:
       self.assertIsInstance(node, gast.Lambda)
       self.assertIn(
-          compiler.ast_to_source(node.body,
-                                 include_encoding_marker=False).strip(),
+          parser.unparse(node.body, include_encoding_marker=False).strip(),
           expected_bodies)
 
   def test_find_matching_definitions_lambda(self):
-    node = parser.parse_str(
+    node = parser.parse(
         textwrap.dedent("""
       f = lambda x: 1
     """))
@@ -218,7 +217,7 @@ class AstUtilTest(test.TestCase):
     self.assertLambdaNodes(nodes, ('(1)',))
 
   def test_find_matching_definitions_lambda_multiple_matches(self):
-    node = parser.parse_str(
+    node = parser.parse(
         textwrap.dedent("""
       f = lambda x: 1, lambda x: 2
     """))
@@ -227,7 +226,7 @@ class AstUtilTest(test.TestCase):
     self.assertLambdaNodes(nodes, ('(1)', '(2)'))
 
   def test_find_matching_definitions_lambda_uses_arg_names(self):
-    node = parser.parse_str(
+    node = parser.parse(
         textwrap.dedent("""
       f = lambda x: 1, lambda y: 2
     """))
diff --git a/tensorflow/python/autograph/pyct/cfg.py b/tensorflow/python/autograph/pyct/cfg.py
index ca3a8e55cf4..c2da09ef72b 100644
--- a/tensorflow/python/autograph/pyct/cfg.py
+++ b/tensorflow/python/autograph/pyct/cfg.py
@@ -38,7 +38,7 @@ from enum import Enum
 import gast
 # pylint:enable=g-bad-import-order
 
-from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import parser
 
 
 class Node(object):
@@ -77,10 +77,9 @@ class Node(object):
     elif isinstance(self.ast_node, gast.ClassDef):
       return 'class %s' % self.ast_node.name
     elif isinstance(self.ast_node, gast.withitem):
-      return compiler.ast_to_source(
+      return parser.unparse(
           self.ast_node.context_expr, include_encoding_marker=False).strip()
-    return compiler.ast_to_source(
-        self.ast_node, include_encoding_marker=False).strip()
+    return parser.unparse(self.ast_node, include_encoding_marker=False).strip()
 
 
 class Graph(
diff --git a/tensorflow/python/autograph/pyct/common_transformers/anf_test.py b/tensorflow/python/autograph/pyct/common_transformers/anf_test.py
index f3bbba20925..e4a5a0accd5 100644
--- a/tensorflow/python/autograph/pyct/common_transformers/anf_test.py
+++ b/tensorflow/python/autograph/pyct/common_transformers/anf_test.py
@@ -22,7 +22,7 @@ import textwrap
 
 import gast
 
-from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import loader
 from tensorflow.python.autograph.pyct import parser
 from tensorflow.python.autograph.pyct import transformer
 from tensorflow.python.autograph.pyct.common_transformers import anf
@@ -76,9 +76,9 @@ class AnfTestBase(test.TestCase):
     return transformer.Context(entity_info)
 
   def assert_same_ast(self, expected_node, node, msg=None):
-    expected_source = compiler.ast_to_source(expected_node, indentation='  ')
+    expected_source = parser.unparse(expected_node, indentation='  ')
     expected_str = textwrap.dedent(expected_source).strip()
-    got_source = compiler.ast_to_source(node, indentation='  ')
+    got_source = parser.unparse(node, indentation='  ')
     got_str = textwrap.dedent(got_source).strip()
     self.assertEqual(expected_str, got_str, msg=msg)
 
@@ -112,7 +112,7 @@ class AnfTransformerTest(AnfTestBase):
 
     node, _ = parser.parse_entity(test_function, future_features=())
     node = anf.transform(node, self._simple_context())
-    result, _, _ = compiler.ast_to_object(node)
+    result, _, _ = loader.load_ast(node)
     self.assertEqual(test_function(), result.test_function())
 
   def test_binop_basic(self):
@@ -463,13 +463,13 @@ class AnfNonTransformationTest(AnfTransformerTest):
     # syntax highlights nicely, but Python doesn't try to execute the
     # statements.
     node, _ = parser.parse_entity(test_fn, future_features=())
-    orig_source = compiler.ast_to_source(node, indentation='  ')
+    orig_source = parser.unparse(node, indentation='  ')
     orig_str = textwrap.dedent(orig_source).strip()
     config = [(anf.ANY, anf.LEAVE)]  # Configuration to trasform nothing
     node = anf.transform(
         node, self._simple_context(),
         config=config, gensym_source=DummyGensym)
-    new_source = compiler.ast_to_source(node, indentation='  ')
+    new_source = parser.unparse(node, indentation='  ')
     new_str = textwrap.dedent(new_source).strip()
     self.assertEqual(orig_str, new_str)
 
diff --git a/tensorflow/python/autograph/pyct/compiler.py b/tensorflow/python/autograph/pyct/loader.py
similarity index 55%
rename from tensorflow/python/autograph/pyct/compiler.py
rename to tensorflow/python/autograph/pyct/loader.py
index 297f28cfeaf..3690833b793 100644
--- a/tensorflow/python/autograph/pyct/compiler.py
+++ b/tensorflow/python/autograph/pyct/loader.py
@@ -29,64 +29,14 @@ import imp
 import os
 import tempfile
 
-import astor
-import gast
 import six
 
 from tensorflow.python.autograph.pyct import origin_info
+from tensorflow.python.autograph.pyct import parser
 from tensorflow.python.autograph.utils import ag_logging
 
 
-def ast_to_source(node, indentation='  ', include_encoding_marker=True):
-  """Return the source code of given AST.
-
-  Args:
-    node: The code to compile, as an AST object.
-    indentation: The string to use for indentation.
-    include_encoding_marker: Bool, thether to include a comment on the first
-      line to explicitly specify UTF-8 encoding.
-
-  Returns:
-    code: The source code generated from the AST object
-    source_mapping: A mapping between the user and AutoGraph generated code.
-  """
-  if not isinstance(node, (list, tuple)):
-    node = (node,)
-  generator = astor.code_gen.SourceGenerator(indentation, False,
-                                             astor.string_repr.pretty_string)
-
-  for n in node:
-    if isinstance(n, gast.AST):
-      n = gast.gast_to_ast(n)
-    generator.visit(n)
-    generator.result.append('\n')
-
-  # In some versions of Python, literals may appear as actual values. This
-  # ensures everything is string.
-  code = ''.join(map(str, generator.result))
-
-  # Strip leading blank lines.
-  code_lines = code.split('\n')
-  trimmed_code_lines = []
-  for l in code_lines:
-    if l.rstrip() or trimmed_code_lines:
-      trimmed_code_lines.append(l)
-  code = '\n'.join(trimmed_code_lines)
-
-  # Work around the reference cycle generated by astor.
-  # See https://github.com/berkerpeksag/astor/blob/55dd323f7d8d696610c703c0296763c567685c31/astor/code_gen.py#L162  # pylint:disable=line-too-long
-  # Reference cycles are quite disliked by TensorFlow's tests.
-  if hasattr(generator, 'write'):
-    generator.write = None
-  del generator
-
-  if include_encoding_marker:
-    code = '# coding=utf-8\n' + code
-
-  return code
-
-
-def source_to_entity(source, delete_on_exit):
+def load_source(source, delete_on_exit):
   """Loads the given source code as a Python module."""
   if six.PY2:
     source = source.encode('utf-8')
@@ -104,23 +54,22 @@ def source_to_entity(source, delete_on_exit):
   return imp.load_source(module_name, f.name), f.name
 
 
-# TODO(mdan): Rename: ast_to_entity
-def ast_to_object(nodes,
-                  indentation='  ',
-                  include_source_map=False,
-                  delete_on_exit=True):
-  """Return the Python objects represented by given AST.
+def load_ast(nodes,
+             indentation='  ',
+             include_source_map=False,
+             delete_on_exit=True):
+  """Loads the given AST as a Python module.
 
   Compiling the AST code this way ensures that the source code is readable by
   e.g. `pdb` or `inspect`.
 
   Args:
     nodes: Union[ast.AST, Iterable[ast.AST]], the code to compile, as an AST
-        object.
+      object.
     indentation: Text, the string to use for indentation.
     include_source_map: bool, whether return a source map.
     delete_on_exit: bool, whether to delete the temporary file used for
-        compilation on exit.
+      compilation on exit.
 
   Returns:
     Tuple[module, Text, Dict[LineLocation, OriginInfo]], containing:
@@ -131,8 +80,8 @@ def ast_to_object(nodes,
   if not isinstance(nodes, (list, tuple)):
     nodes = (nodes,)
 
-  source = ast_to_source(nodes, indentation=indentation)
-  module, _ = source_to_entity(source, delete_on_exit)
+  source = parser.unparse(nodes, indentation=indentation)
+  module, _ = load_source(source, delete_on_exit)
 
   if include_source_map:
     source_map = origin_info.create_source_map(nodes, source, module.__file__)
diff --git a/tensorflow/python/autograph/pyct/compiler_test.py b/tensorflow/python/autograph/pyct/loader_test.py
similarity index 71%
rename from tensorflow/python/autograph/pyct/compiler_test.py
rename to tensorflow/python/autograph/pyct/loader_test.py
index 3be0060612a..da7e336c5bc 100644
--- a/tensorflow/python/autograph/pyct/compiler_test.py
+++ b/tensorflow/python/autograph/pyct/loader_test.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for compiler module."""
+"""Tests for loader module."""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -23,15 +23,15 @@ import textwrap
 
 import gast
 
-from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import loader
 from tensorflow.python.autograph.pyct import parser
 from tensorflow.python.platform import test
 from tensorflow.python.util import tf_inspect
 
 
-class CompilerTest(test.TestCase):
+class LoaderTest(test.TestCase):
 
-  def test_parser_compile_identity(self):
+  def test_parse_load_identity(self):
 
     def test_fn(x):
       a = True
@@ -41,37 +41,13 @@ class CompilerTest(test.TestCase):
       return b
 
     node, _ = parser.parse_entity(test_fn, future_features=())
-    module, _, _ = compiler.ast_to_object(node)
+    module, _, _ = loader.load_ast(node)
 
     self.assertEqual(
         textwrap.dedent(tf_inspect.getsource(test_fn)),
         tf_inspect.getsource(module.test_fn))
 
-  def test_ast_to_source(self):
-    node = gast.If(
-        test=gast.Num(1),
-        body=[
-            gast.Assign(
-                targets=[gast.Name('a', gast.Store(), None)],
-                value=gast.Name('b', gast.Load(), None))
-        ],
-        orelse=[
-            gast.Assign(
-                targets=[gast.Name('a', gast.Store(), None)],
-                value=gast.Str('c'))
-        ])
-
-    source = compiler.ast_to_source(node, indentation='  ')
-    self.assertEqual(
-        textwrap.dedent("""
-            # coding=utf-8
-            if 1:
-              a = b
-            else:
-              a = 'c'
-        """).strip(), source.strip())
-
-  def test_ast_to_object(self):
+  def test_load_ast(self):
     node = gast.FunctionDef(
         name='f',
         args=gast.arguments(
@@ -91,7 +67,7 @@ class CompilerTest(test.TestCase):
         decorator_list=[],
         returns=None)
 
-    module, source, _ = compiler.ast_to_object(node)
+    module, source, _ = loader.load_ast(node)
 
     expected_source = """
       # coding=utf-8
@@ -107,14 +83,14 @@ class CompilerTest(test.TestCase):
           textwrap.dedent(expected_source).strip(),
           temp_output.read().strip())
 
-  def test_source_to_entity(self):
+  def test_load_source(self):
     test_source = textwrap.dedent(u"""
       # coding=utf-8
       def f(a):
         '日本語 Δθₜ ← Δθₜ₋₁ + ∇Q(sₜ, aₜ)(rₜ + γₜ₊₁ max Q(⋅))'
         return a + 1
     """)
-    module, _ = compiler.source_to_entity(test_source, delete_on_exit=True)
+    module, _ = loader.load_source(test_source, delete_on_exit=True)
     self.assertEqual(module.f(1), 2)
     self.assertEqual(
         module.f.__doc__, '日本語 Δθₜ ← Δθₜ₋₁ + ∇Q(sₜ, aₜ)(rₜ + γₜ₊₁ max Q(⋅))')
diff --git a/tensorflow/python/autograph/pyct/origin_info.py b/tensorflow/python/autograph/pyct/origin_info.py
index 5479fefbb22..ae1d5e18334 100644
--- a/tensorflow/python/autograph/pyct/origin_info.py
+++ b/tensorflow/python/autograph/pyct/origin_info.py
@@ -102,7 +102,7 @@ def create_source_map(nodes, code, filepath):
     Dict[LineLocation, OriginInfo], mapping locations in code to locations
     indicated by origin annotations in node.
   """
-  reparsed_nodes = parser.parse_str(code, preamble_len=0, single_node=False)
+  reparsed_nodes = parser.parse(code, preamble_len=0, single_node=False)
   for node in reparsed_nodes:
     resolve(node, code, filepath, node.lineno, node.col_offset)
 
diff --git a/tensorflow/python/autograph/pyct/origin_info_test.py b/tensorflow/python/autograph/pyct/origin_info_test.py
index 91c6ee5778f..01ded4cc559 100644
--- a/tensorflow/python/autograph/pyct/origin_info_test.py
+++ b/tensorflow/python/autograph/pyct/origin_info_test.py
@@ -39,7 +39,7 @@ class OriginInfoTest(test.TestCase):
     """
     source = textwrap.dedent(source)
 
-    node = parser.parse_str(source)
+    node = parser.parse(source)
     fake_origin = origin_info.OriginInfo(
         loc=origin_info.Location('fake_filename', 3, 7),
         function_name='fake_function_name',
@@ -118,7 +118,7 @@ class OriginInfoTest(test.TestCase):
         return x  # comment
     """
     source = textwrap.dedent(source)
-    node = parser.parse_str(source)
+    node = parser.parse(source)
     origin_info.resolve(node, source, 'test_file', 10, 10)
 
     def_origin = anno.getanno(node, anno.Basic.ORIGIN)
diff --git a/tensorflow/python/autograph/pyct/parser.py b/tensorflow/python/autograph/pyct/parser.py
index c5b2fe5832a..1b745fa4219 100644
--- a/tensorflow/python/autograph/pyct/parser.py
+++ b/tensorflow/python/autograph/pyct/parser.py
@@ -25,6 +25,7 @@ import re
 import textwrap
 import tokenize
 
+import astor
 import gast
 import six
 
@@ -109,7 +110,7 @@ def dedent_block(code_string):
 
 
 def _attempt_to_parse_normal_source(source, future_features):
-  return parse_str(source, preamble_len=len(future_features)), source
+  return parse(source, preamble_len=len(future_features)), source
 
 
 def _attempt_to_parse_lambda_source(source, original_source,
@@ -131,17 +132,17 @@ def _attempt_to_parse_lambda_source(source, original_source,
     source: the processed source code of `entity`.
     original_source: the source code of `entity`, as it was reported
         by `inspect.getsource`.
-    future_features: see `parse_str`.
+    future_features: see `parse`.
     try_fallback: whether to attempt to remove extra code from `source` before
         one more attempt to parse it.
   Returns:
-    Same as `parse_str`.
+    Same as `parse`.
   """
 
   try:
-    return parse_str(source, preamble_len=len(future_features)), source
+    return parse(source, preamble_len=len(future_features)), source
 
-  # Note: the ValueError may be raised by parse_str.
+  # Note: the ValueError may be raised by parse.
   except (SyntaxError, ValueError) as e:
     def fail():
       raise errors.UnsupportedLanguageElementError(
@@ -209,7 +210,7 @@ def parse_entity(entity, future_features):
 
 
 # TODO(mdan): This should take futures as input instead.
-def parse_str(src, preamble_len=0, single_node=True):
+def parse(src, preamble_len=0, single_node=True):
   """Returns the AST of given piece of code.
 
   Args:
@@ -244,9 +245,58 @@ def parse_expression(src):
     ValueError: if src does not consist of a single Expression.
   """
   src = STANDARD_PREAMBLE + src.strip()
-  node = parse_str(src, preamble_len=STANDARD_PREAMBLE_LEN, single_node=True)
+  node = parse(src, preamble_len=STANDARD_PREAMBLE_LEN, single_node=True)
   if __debug__:
     if not isinstance(node, gast.Expr):
       raise ValueError(
           'expected a single expression, found instead {}'.format(node))
   return node.value
+
+
+def unparse(node, indentation='  ', include_encoding_marker=True):
+  """Returns the source code of given AST.
+
+  Args:
+    node: The code to compile, as an AST object.
+    indentation: The string to use for indentation.
+    include_encoding_marker: Bool, thether to include a comment on the first
+      line to explicitly specify UTF-8 encoding.
+
+  Returns:
+    code: The source code generated from the AST object
+    source_mapping: A mapping between the user and AutoGraph generated code.
+  """
+  if not isinstance(node, (list, tuple)):
+    node = (node,)
+  generator = astor.code_gen.SourceGenerator(indentation, False,
+                                             astor.string_repr.pretty_string)
+
+  for n in node:
+    if isinstance(n, gast.AST):
+      n = gast.gast_to_ast(n)
+    generator.visit(n)
+    generator.result.append('\n')
+
+  # In some versions of Python, literals may appear as actual values. This
+  # ensures everything is string.
+  code = ''.join(map(str, generator.result))
+
+  # Strip leading blank lines.
+  code_lines = code.split('\n')
+  trimmed_code_lines = []
+  for l in code_lines:
+    if l.rstrip() or trimmed_code_lines:
+      trimmed_code_lines.append(l)
+  code = '\n'.join(trimmed_code_lines)
+
+  # Work around the reference cycle generated by astor.
+  # See https://github.com/berkerpeksag/astor/blob/55dd323f7d8d696610c703c0296763c567685c31/astor/code_gen.py#L162  # pylint:disable=line-too-long
+  # Reference cycles are quite disliked by TensorFlow's tests.
+  if hasattr(generator, 'write'):
+    generator.write = None
+  del generator
+
+  if include_encoding_marker:
+    code = '# coding=utf-8\n' + code
+
+  return code
diff --git a/tensorflow/python/autograph/pyct/parser_test.py b/tensorflow/python/autograph/pyct/parser_test.py
index ef62d140525..f5c1dcb7021 100644
--- a/tensorflow/python/autograph/pyct/parser_test.py
+++ b/tensorflow/python/autograph/pyct/parser_test.py
@@ -18,6 +18,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import textwrap
+
+import gast
+
 from tensorflow.python.autograph.pyct import parser
 from tensorflow.python.platform import test
 
@@ -130,6 +134,30 @@ string""")
     self.assertEqual('a', node.value.id)
     self.assertEqual('b', node.attr)
 
+  def test_unparse(self):
+    node = gast.If(
+        test=gast.Num(1),
+        body=[
+            gast.Assign(
+                targets=[gast.Name('a', gast.Store(), None)],
+                value=gast.Name('b', gast.Load(), None))
+        ],
+        orelse=[
+            gast.Assign(
+                targets=[gast.Name('a', gast.Store(), None)],
+                value=gast.Str('c'))
+        ])
+
+    source = parser.unparse(node, indentation='  ')
+    self.assertEqual(
+        textwrap.dedent("""
+            # coding=utf-8
+            if 1:
+              a = b
+            else:
+              a = 'c'
+        """).strip(), source.strip())
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/autograph/pyct/qual_names_test.py b/tensorflow/python/autograph/pyct/qual_names_test.py
index 48db7bd7fe0..f32bf19e946 100644
--- a/tensorflow/python/autograph/pyct/qual_names_test.py
+++ b/tensorflow/python/autograph/pyct/qual_names_test.py
@@ -192,7 +192,7 @@ class QNResolverTest(test.TestCase):
       [f, (g.h.i)]
       j(k, l)
     """
-    nodes = parser.parse_str(textwrap.dedent(samples), single_node=False)
+    nodes = parser.parse(textwrap.dedent(samples), single_node=False)
     nodes = tuple(resolve(node).value for node in nodes)
 
     self.assertQNStringIs(nodes[0], 'a')
@@ -218,7 +218,7 @@ class QNResolverTest(test.TestCase):
       a.b[c[d]].e.f
       a.b[c[d.e.f].g].h
     """
-    nodes = parser.parse_str(textwrap.dedent(samples), single_node=False)
+    nodes = parser.parse(textwrap.dedent(samples), single_node=False)
     nodes = tuple(resolve(node).value for node in nodes)
 
     self.assertQNStringIs(nodes[0], 'x[i]')
@@ -241,7 +241,7 @@ class QNResolverTest(test.TestCase):
       z[i]()
       z()[i]
     """
-    nodes = parser.parse_str(textwrap.dedent(samples), single_node=False)
+    nodes = parser.parse(textwrap.dedent(samples), single_node=False)
     nodes = tuple(resolve(node).value for node in nodes)
     self.assertQNStringIs(nodes[0], 'a.b')
     self.assertQNStringIs(nodes[1].func, 'a.b')
diff --git a/tensorflow/python/autograph/pyct/templates.py b/tensorflow/python/autograph/pyct/templates.py
index 253e2943a12..24d2a0760b9 100644
--- a/tensorflow/python/autograph/pyct/templates.py
+++ b/tensorflow/python/autograph/pyct/templates.py
@@ -260,7 +260,7 @@ def replace(template, **replacements):
   for k in replacements:
     replacements[k] = _convert_to_ast(replacements[k])
   template_str = parser.STANDARD_PREAMBLE + textwrap.dedent(template)
-  nodes = parser.parse_str(
+  nodes = parser.parse(
       template_str,
       preamble_len=parser.STANDARD_PREAMBLE_LEN,
       single_node=False)
diff --git a/tensorflow/python/autograph/pyct/templates_test.py b/tensorflow/python/autograph/pyct/templates_test.py
index 5ed10d9c937..2085e555ff4 100644
--- a/tensorflow/python/autograph/pyct/templates_test.py
+++ b/tensorflow/python/autograph/pyct/templates_test.py
@@ -23,7 +23,7 @@ import imp
 from absl.testing import parameterized
 import gast
 
-from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import loader
 from tensorflow.python.autograph.pyct import parser
 from tensorflow.python.autograph.pyct import qual_names as qn
 from tensorflow.python.autograph.pyct import templates
@@ -75,7 +75,7 @@ class TemplatesTest(test.TestCase, parameterized.TestCase):
     """
 
     node = templates.replace(template, b=('a', 'c'))[0]
-    result, _, _ = compiler.ast_to_object(node)
+    result, _, _ = loader.load_ast(node)
 
     self.assertEqual((2, 3), result.test_fn(2, 3))
 
@@ -88,7 +88,7 @@ class TemplatesTest(test.TestCase, parameterized.TestCase):
     """
 
     node = templates.replace(template, a='b')[0]
-    result, _, _ = compiler.ast_to_object(node)
+    result, _, _ = loader.load_ast(node)
     self.assertEqual(7, result.test_fn(2))
 
   def test_replace_function_name(self):
@@ -100,7 +100,7 @@ class TemplatesTest(test.TestCase, parameterized.TestCase):
     """
 
     node = templates.replace(template, fname='test_fn')[0]
-    result, _, _ = compiler.ast_to_object(node)
+    result, _, _ = loader.load_ast(node)
     self.assertEqual(7, result.test_fn(2))
 
   def test_replace_code_block(self):
@@ -117,7 +117,7 @@ class TemplatesTest(test.TestCase, parameterized.TestCase):
                 gast.Name('a', None, None)
             ], gast.BinOp(gast.Name('a', None, None), gast.Add(), gast.Num(1))),
         ] * 2)[0]
-    result, _, _ = compiler.ast_to_object(node)
+    result, _, _ = loader.load_ast(node)
     self.assertEqual(3, result.test_fn(1))
 
   def test_replace_attribute(self):
@@ -127,7 +127,7 @@ class TemplatesTest(test.TestCase, parameterized.TestCase):
     """
 
     node = templates.replace(template, foo='b')[0]
-    result, _, _ = compiler.ast_to_object(node)
+    result, _, _ = loader.load_ast(node)
     mod = imp.new_module('test')
     mod.b = 3
     self.assertEqual(3, result.test_fn(mod))
@@ -217,7 +217,7 @@ class TemplatesTest(test.TestCase, parameterized.TestCase):
 
     source = parser.parse_expression('f(d=3, f=5)')
     node = templates.replace(template, kws=source.keywords)[0]
-    result, _, _ = compiler.ast_to_object(node)
+    result, _, _ = loader.load_ast(node)
     self.assertEqual(9, result.test_fn())
 
     with self.assertRaises(ValueError):
@@ -237,7 +237,7 @@ class TemplatesTest(test.TestCase, parameterized.TestCase):
 
     source = parser.parse_expression('f()(b)')
     node = templates.replace(template, foo=source)[0]
-    result, _, _ = compiler.ast_to_object(node)
+    result, _, _ = loader.load_ast(node)
     self.assertEqual(15, result.test_fn())
 
   def test_replace_name_with_dict(self):
@@ -248,7 +248,7 @@ class TemplatesTest(test.TestCase, parameterized.TestCase):
 
     source = parser.parse_expression('{\'bar\': 3}')
     node = templates.replace(template, foo=source)[0]
-    result, _, _ = compiler.ast_to_object(node)
+    result, _, _ = loader.load_ast(node)
     self.assertEqual(3, result.test_fn())
 
   def test_replace_as_expression(self):
diff --git a/tensorflow/python/autograph/pyct/transformer.py b/tensorflow/python/autograph/pyct/transformer.py
index 592ff0c45e6..ddc31737155 100644
--- a/tensorflow/python/autograph/pyct/transformer.py
+++ b/tensorflow/python/autograph/pyct/transformer.py
@@ -23,7 +23,7 @@ import collections
 import gast
 
 from tensorflow.python.autograph.pyct import anno
-from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import loader
 from tensorflow.python.autograph.pyct import pretty_printer
 from tensorflow.python.autograph.pyct import templates
 
@@ -301,7 +301,7 @@ class Base(gast.NodeTransformer):
   def debug_print_src(self, node):
     """Helper method useful for debugging. Prints the AST as code."""
     if __debug__:
-      print(compiler.ast_to_source(node))
+      print(loader.load_ast(node))
     return node
 
   def create_assignment(self, target, expression):
@@ -436,7 +436,7 @@ class Base(gast.NodeTransformer):
 
   def _get_source(self, node):
     try:
-      source, _ = compiler.ast_to_source(node)
+      source, _ = loader.load_ast(node)
       return source
     # pylint: disable=broad-except
     # This function is used for error reporting.  If an exception occurs here,

From cf03f5048cf4cdf41d7dbeb14bb2ae89488fcd53 Mon Sep 17 00:00:00 2001
From: Eugene Brevdo 
Date: Mon, 2 Dec 2019 14:08:20 -0800
Subject: [PATCH 173/279] [TF OSS] tf_proto_library also produces a proper
 proto_library rule of the same name.

PiperOrigin-RevId: 283414137
Change-Id: I7f88066e7c33888a27b49f5252e8884d1e852bf8
---
 .../core/platform/default/build_config.bzl    | 57 +++++++++++--------
 1 file changed, 32 insertions(+), 25 deletions(-)

diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 33e815c2a3f..a95de6632ce 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -10,6 +10,26 @@ load(
     "if_mkl_ml",
 )
 
+def well_known_proto_libs():
+    """Set of standard protobuf protos, like Any and Timestamp.
+
+    This list should be provided by protobuf.bzl, but it's not.
+    """
+    return [
+        "@com_google_protobuf//:any_proto",
+        "@com_google_protobuf//:api_proto",
+        "@com_google_protobuf//:compiler_plugin_proto",
+        "@com_google_protobuf//:descriptor_proto",
+        "@com_google_protobuf//:duration_proto",
+        "@com_google_protobuf//:empty_proto",
+        "@com_google_protobuf//:field_mask_proto",
+        "@com_google_protobuf//:source_context_proto",
+        "@com_google_protobuf//:struct_proto",
+        "@com_google_protobuf//:timestamp_proto",
+        "@com_google_protobuf//:type_proto",
+        "@com_google_protobuf//:wrappers_proto",
+    ]
+
 # Appends a suffix to a list of deps.
 def tf_deps(deps, suffix):
     tf_deps = []
@@ -259,18 +279,6 @@ def cc_proto_library(
         **kargs
     )
 
-    # Temporarily also add an alias with the 'protolib_name'. So far we relied
-    # on copybara to switch dependencies to the _cc dependencies. Now that these
-    # copybara rules are removed, we need to first change the internal BUILD
-    # files to depend on the correct targets instead, then this can be removed.
-    # TODO(b/143648532): Remove this once all reverse dependencies are migrated.
-    if protolib_name != name:
-        native.alias(
-            name = protolib_name,
-            actual = name,
-            visibility = kargs["visibility"],
-        )
-
 # Re-defined protocol buffer rule to bring in the change introduced in commit
 # https://github.com/google/protobuf/commit/294b5758c373cbab4b72f35f4cb62dc1d8332b68
 # which was not part of a stable protobuf release in 04/2018.
@@ -386,19 +394,6 @@ def tf_proto_library_cc(
             deps = [s + "_genproto" for s in protolib_deps],
         )
 
-        # Temporarily also add an alias with 'name'. So far we relied on
-        # copybara to switch dependencies to the _cc dependencies. Now that these
-        # copybara rules are removed, we need to change the internal BUILD files to
-        # depend on the correct targets instead.
-        # TODO(b/143648532): Remove this once all reverse dependencies are
-        # migrated.
-        native.alias(
-            name = name,
-            actual = cc_name,
-            testonly = testonly,
-            visibility = visibility,
-        )
-
         native.alias(
             name = cc_name + "_headers_only",
             actual = cc_name,
@@ -504,8 +499,20 @@ def tf_proto_library(
         make_default_target_header_only = False,
         exports = []):
     """Make a proto library, possibly depending on other proto libraries."""
+
+    # TODO(b/145545130): Add docstring explaining what rules this creates and how
+    # opensource projects importing TF in bazel can use them safely (i.e. w/o ODR or
+    # ABI violations).
     _ignore = (js_codegen, exports)
 
+    native.proto_library(
+        name = name,
+        srcs = srcs,
+        deps = protodeps + well_known_proto_libs(),
+        visibility = visibility,
+        testonly = testonly,
+    )
+
     tf_proto_library_cc(
         name = name,
         testonly = testonly,

From e439e8713300f8ad0468ce70ccb83474a8a294c8 Mon Sep 17 00:00:00 2001
From: Renjie Liu 
Date: Mon, 2 Dec 2019 14:10:42 -0800
Subject: [PATCH 174/279] Change the optimized path mean 4d to use integer only
 & also fix a dumb error in the reference path.

PiperOrigin-RevId: 283414594
Change-Id: Ibd7f6d35a2bf61734ff96474a0d4851e61fc1b1c
---
 .../internal/optimized/optimized_ops.h        | 328 ++++++++++--------
 .../internal/reference/reference_ops.h        |   2 +-
 tensorflow/lite/kernels/reduce.cc             |  23 +-
 tensorflow/lite/kernels/reduce_test.cc        |  17 +-
 4 files changed, 207 insertions(+), 163 deletions(-)

diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
index e478fb87720..f8fc6113e61 100644
--- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
@@ -195,6 +195,71 @@ MatrixMap MapAsMatrixWithGivenNumberOfRows(Scalar* data,
   return MatrixMap(data, rows, cols);
 }
 
+// TODO(renjieliu): Refactor this to merge with other
+// MultiplyByQuantizedMultipler.
+#ifdef USE_NEON
+inline int32x4x4_t MultiplyByQuantizedMultiplier4Rows(
+    int32x4x4_t input_val, int32 quantized_multiplier, int shift) {
+  using gemmlowp::RoundingDivideByPOT;
+  using gemmlowp::SaturatingRoundingDoublingHighMul;
+  const int left_shift = shift > 0 ? shift : 0;
+  const int right_shift = shift > 0 ? 0 : -shift;
+  int32x4x4_t result;
+  // The vector type support for SaturatingRoundingDoublingHighMulth in gemmlowp
+  // is limited to NEON.
+#ifdef GEMMLOWP_NEON
+  const int32x4_t left_shifted_one_dup = vdupq_n_s32(1 << left_shift);
+  result.val[0] =
+      RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
+                              vmulq_s32(input_val.val[0], left_shifted_one_dup),
+                              quantized_multiplier),
+                          right_shift);
+  result.val[1] =
+      RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
+                              vmulq_s32(input_val.val[1], left_shifted_one_dup),
+                              quantized_multiplier),
+                          right_shift);
+  result.val[2] =
+      RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
+                              vmulq_s32(input_val.val[2], left_shifted_one_dup),
+                              quantized_multiplier),
+                          right_shift);
+  result.val[3] =
+      RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
+                              vmulq_s32(input_val.val[3], left_shifted_one_dup),
+                              quantized_multiplier),
+                          right_shift);
+#else
+  for (int i = 0; i < 4; ++i) {
+    int32_t vals[4];
+    vals[0] = RoundingDivideByPOT(
+        SaturatingRoundingDoublingHighMul(
+            vgetq_lane_s32(input_val.val[i], 0) * (1 << left_shift),
+            quantized_multiplier),
+        right_shift);
+    vals[1] = RoundingDivideByPOT(
+        SaturatingRoundingDoublingHighMul(
+            vgetq_lane_s32(input_val.val[i], 1) * (1 << left_shift),
+            quantized_multiplier),
+        right_shift);
+    vals[2] = RoundingDivideByPOT(
+        SaturatingRoundingDoublingHighMul(
+            vgetq_lane_s32(input_val.val[i], 2) * (1 << left_shift),
+            quantized_multiplier),
+        right_shift);
+    vals[3] = RoundingDivideByPOT(
+        SaturatingRoundingDoublingHighMul(
+            vgetq_lane_s32(input_val.val[i], 3) * (1 << left_shift),
+            quantized_multiplier),
+        right_shift);
+
+    result.val[i] = vld1q_s32(reinterpret_cast(&vals));
+  }
+#endif
+  return result;
+}
+#endif
+
 inline void AddBiasAndEvalActivationFunction(float output_activation_min,
                                              float output_activation_max,
                                              const RuntimeShape& bias_shape,
@@ -849,9 +914,8 @@ inline uint32x4_t RoundToNearestUnsigned(const float32x4_t input) {
 
 inline void MeanImpl(const tflite::MeanParams& op_params,
                      const RuntimeShape& input_shape, const uint8_t* input_data,
-                     int32 input_zero_point, float input_scale,
+                     int32 multiplier, int32 shift, int32 bias,
                      const RuntimeShape& output_shape, uint8_t* output_data,
-                     int32 output_zero_point, float output_scale,
                      int start_depth, int end_depth) {
   gemmlowp::ScopedProfilingLabel label("Mean4D/Uint8/MeanImpl");
 
@@ -862,7 +926,6 @@ inline void MeanImpl(const tflite::MeanParams& op_params,
   const int output_width = output_shape.Dims(2);
   const int input_height = input_shape.Dims(1);
   const int input_width = input_shape.Dims(2);
-  const float num_elements_in_axis = input_width * input_height;
 
   TFLITE_CHECK_EQ(op_params.axis_count, 2);
   TFLITE_CHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
@@ -870,83 +933,103 @@ inline void MeanImpl(const tflite::MeanParams& op_params,
   TFLITE_CHECK_EQ(output_height, 1);
   TFLITE_CHECK_EQ(output_width, 1);
 
-  const bool ordinary_mean =
-      (input_zero_point == output_zero_point && input_scale == output_scale);
-  float scale = 0.0f, bias = 0.0f;
-  if (!ordinary_mean) {
-    scale = input_scale / output_scale;
-    bias = -input_zero_point * scale + 0.5;
-  }
+  constexpr int32_t kMinValue = std::numeric_limits::min();
+  constexpr int32_t kMaxValue = std::numeric_limits::max();
 
 #ifdef USE_NEON
-  const float32x4_t num_elements_dup = vdupq_n_f32(num_elements_in_axis);
-  // This is only an approximation as NEON does not offer division instruction.
-  const float32x4_t scale_dup = vdupq_n_f32(scale);
-  const float32x4_t num_elements_reverse = vrecpeq_f32(num_elements_dup);
-  float32x4_t zero_point_with_bias_dup = vdupq_n_f32(output_zero_point + bias);
+  const int32x4_t bias_dup = vdupq_n_s32(bias);
+  const int32x4_t min_dup = vdupq_n_s32(kMinValue);
+  const int32x4_t max_dup = vdupq_n_s32(kMaxValue);
 #endif  // USE_NEON
 
   for (int out_b = 0; out_b < output_batch; ++out_b) {
     int out_d = start_depth;
 #ifdef USE_NEON
 
-    for (; out_d < end_depth - 8; out_d += 8) {
-      float32x4_t temp_sum_1 = vdupq_n_f32(0);
-      float32x4_t temp_sum_2 = vdupq_n_f32(0);
+    for (; out_d <= end_depth - 16; out_d += 16) {
+      int32x4x4_t temp_sum;
+      temp_sum.val[0] = vdupq_n_s32(0);
+      temp_sum.val[1] = vdupq_n_s32(0);
+      temp_sum.val[2] = vdupq_n_s32(0);
+      temp_sum.val[3] = vdupq_n_s32(0);
       for (int in_h = 0; in_h < input_height; ++in_h) {
         for (int in_w = 0; in_w < input_width; ++in_w) {
           const uint8_t* input_data_ptr =
               input_data + Offset(input_shape, out_b, in_h, in_w, out_d);
-          uint8x8_t input_data_val = vld1_u8(input_data_ptr);
-          int16x8_t input_data_val_shift =
-              vreinterpretq_s16_u16(vmovl_u8(input_data_val));
-          float32x4_t input_float_1 =
-              vcvtq_f32_s32(vmovl_s16(vget_high_s16(input_data_val_shift)));
-          float32x4_t input_float_2 =
-              vcvtq_f32_s32(vmovl_s16(vget_low_s16(input_data_val_shift)));
-          temp_sum_1 = vaddq_f32(temp_sum_1, input_float_1);
-          temp_sum_2 = vaddq_f32(temp_sum_2, input_float_2);
+          uint8x16_t input_data_val = vld1q_u8(input_data_ptr);
+
+          int16x8_t input_data_low_shift =
+              vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_data_val)));
+          int16x8_t input_data_high_shift =
+              vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_data_val)));
+
+          int32x4_t input_low_low =
+              vmovl_s16(vget_low_s16(input_data_low_shift));
+          int32x4_t input_high_low =
+              vmovl_s16(vget_high_s16(input_data_low_shift));
+          int32x4_t input_low_high =
+              vmovl_s16(vget_low_s16(input_data_high_shift));
+          int32x4_t input_high_high =
+              vmovl_s16(vget_high_s16(input_data_high_shift));
+
+          temp_sum.val[0] = vaddq_s32(temp_sum.val[0], input_low_low);
+          temp_sum.val[1] = vaddq_s32(temp_sum.val[1], input_high_low);
+          temp_sum.val[2] = vaddq_s32(temp_sum.val[2], input_low_high);
+          temp_sum.val[3] = vaddq_s32(temp_sum.val[3], input_high_high);
         }
       }
 
-      const float32x4_t mean_1 =
-          DivideSumForMeanImpl(temp_sum_1, num_elements_reverse, ordinary_mean,
-                               scale_dup, zero_point_with_bias_dup);
-      const float32x4_t mean_2 =
-          DivideSumForMeanImpl(temp_sum_2, num_elements_reverse, ordinary_mean,
-                               scale_dup, zero_point_with_bias_dup);
+      temp_sum =
+          MultiplyByQuantizedMultiplier4Rows(temp_sum, multiplier, shift);
+
+      temp_sum.val[0] = vaddq_s32(temp_sum.val[0], bias_dup);
+      temp_sum.val[1] = vaddq_s32(temp_sum.val[1], bias_dup);
+      temp_sum.val[2] = vaddq_s32(temp_sum.val[2], bias_dup);
+      temp_sum.val[3] = vaddq_s32(temp_sum.val[3], bias_dup);
+
+      temp_sum.val[0] = vminq_s32(vmaxq_s32(temp_sum.val[0], min_dup), max_dup);
+      temp_sum.val[1] = vminq_s32(vmaxq_s32(temp_sum.val[1], min_dup), max_dup);
+      temp_sum.val[2] = vminq_s32(vmaxq_s32(temp_sum.val[2], min_dup), max_dup);
+      temp_sum.val[3] = vminq_s32(vmaxq_s32(temp_sum.val[3], min_dup), max_dup);
+
+      uint16x4_t narrowed_low_low =
+          vmovn_u32(vreinterpretq_u32_s32(temp_sum.val[0]));
+      uint16x4_t narrowed_high_low =
+          vmovn_u32(vreinterpretq_u32_s32(temp_sum.val[1]));
+      uint16x4_t narrowed_low_high =
+          vmovn_u32(vreinterpretq_u32_s32(temp_sum.val[2]));
+      uint16x4_t narrowed_high_high =
+          vmovn_u32(vreinterpretq_u32_s32(temp_sum.val[3]));
+
+      uint16x8_t combined_low =
+          vcombine_s16(narrowed_low_low, narrowed_high_low);
+      uint16x8_t combined_high =
+          vcombine_s16(narrowed_low_high, narrowed_high_high);
+
+      uint8x8_t narrowed_low = vmovn_u16(combined_low);
+      uint8x8_t narrowed_high = vmovn_u16(combined_high);
+
+      uint8x16_t combined_output = vcombine_s8(narrowed_low, narrowed_high);
 
-      uint32x4_t casted_mean_1 = RoundToNearestUnsigned(mean_1);
-      uint16x4_t narrow_range_mean_1 = vmovn_u32(casted_mean_1);
-      uint32x4_t casted_mean_2 = RoundToNearestUnsigned(mean_2);
-      uint16x4_t narrow_range_mean_2 = vmovn_u32(casted_mean_2);
-      uint16x8_t combined_mean =
-          vcombine_u16(narrow_range_mean_2, narrow_range_mean_1);
-      uint8x8_t narrowed_combined_mean = vmovn_u16(combined_mean);
       uint8_t* output_data_ptr =
           output_data + Offset(output_shape, out_b, 0, 0, out_d);
-      vst1_u8(output_data_ptr, narrowed_combined_mean);
+      vst1q_u8(output_data_ptr, combined_output);
     }
 #endif  // USE_NEON
 
     for (; out_d < end_depth; ++out_d) {
-      float temp_value = 0;
+      int acc = 0;
       for (int in_h = 0; in_h < input_height; ++in_h) {
         for (int in_w = 0; in_w < input_width; ++in_w) {
-          temp_value +=
-              input_data[Offset(input_shape, out_b, in_h, in_w, out_d)];
+          acc += input_data[Offset(input_shape, out_b, in_h, in_w, out_d)];
         }
       }
 
-      temp_value = temp_value / num_elements_in_axis;
-      if (ordinary_mean) {
-        output_data[Offset(output_shape, out_b, 0, 0, out_d)] =
-            static_cast(round(temp_value));
-      } else {
-        output_data[Offset(output_shape, out_b, 0, 0, out_d)] =
-            static_cast(round(temp_value * scale + bias)) +
-            output_zero_point;
-      }
+      acc = MultiplyByQuantizedMultiplier(acc, multiplier, shift);
+      acc += bias;
+      acc = std::min(std::max(acc, kMinValue), kMaxValue);
+      output_data[Offset(output_shape, out_b, 0, 0, out_d)] =
+          static_cast(acc);
     }
   }
 }
@@ -954,40 +1037,36 @@ inline void MeanImpl(const tflite::MeanParams& op_params,
 struct MeanWorkerTask : cpu_backend_threadpool::Task {
   MeanWorkerTask(const tflite::MeanParams& op_params,
                  const RuntimeShape& input_shape, const uint8_t* input_data,
-                 int32 input_zero_point, float input_scale,
+                 int32 multiplier, int32 shift, int32 bias,
                  const RuntimeShape& output_shape, uint8_t* output_data,
-                 int32 output_zero_point, float output_scale, int start_height,
-                 int end_height)
-      : op_params_(op_params),
-        input_shape_(input_shape),
-        input_data_(input_data),
-        input_zero_point_(input_zero_point),
-        input_scale_(input_scale),
-        output_shape_(output_shape),
-        output_data_(output_data),
-        output_zero_point_(output_zero_point),
-        output_scale_(output_scale),
-        start_height_(start_height),
-        end_height_(end_height) {}
+                 int start_height, int end_height)
+      : op_params(op_params),
+        input_shape(input_shape),
+        input_data(input_data),
+        multiplier(multiplier),
+        shift(shift),
+        bias(bias),
+        output_shape(output_shape),
+        output_data(output_data),
+        start_height(start_height),
+        end_height(end_height) {}
 
   void Run() override {
-    MeanImpl(op_params_, input_shape_, input_data_, input_zero_point_,
-             input_scale_, output_shape_, output_data_, output_zero_point_,
-             output_scale_, start_height_, end_height_);
+    MeanImpl(op_params, input_shape, input_data, multiplier, shift, bias,
+             output_shape, output_data, start_height, end_height);
   }
 
  private:
-  const tflite::MeanParams& op_params_;
-  const RuntimeShape& input_shape_;
-  const uint8_t* input_data_;
-  int32 input_zero_point_;
-  float input_scale_;
-  const RuntimeShape& output_shape_;
-  uint8_t* output_data_;
-  int32 output_zero_point_;
-  float output_scale_;
-  int start_height_;
-  int end_height_;
+  const tflite::MeanParams& op_params;
+  const RuntimeShape& input_shape;
+  const uint8_t* input_data;
+  int32 multiplier;
+  int32 shift;
+  int32 bias;
+  const RuntimeShape& output_shape;
+  uint8_t* output_data;
+  int start_height;
+  int end_height;
 };
 
 inline void Mean(const tflite::MeanParams& op_params,
@@ -1015,6 +1094,18 @@ inline void Mean(const tflite::MeanParams& op_params,
   TFLITE_CHECK_EQ(output_height, 1);
   TFLITE_CHECK_EQ(output_width, 1);
 
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const float num_elements_in_axis = input_width * input_height;
+
+  int32 bias =
+      output_zero_point -
+      static_cast(input_zero_point * input_scale / output_scale);
+  float real_scale = input_scale / (num_elements_in_axis * output_scale);
+
+  int32 multiplier, shift;
+  QuantizeMultiplier(real_scale, &multiplier, &shift);
+
   constexpr int kMinDepthPerThread = 8;
   int thread_count = output_depth / kMinDepthPerThread;
   thread_count = thread_count > 0 ? thread_count : 1;
@@ -1022,9 +1113,8 @@ inline void Mean(const tflite::MeanParams& op_params,
       std::min(thread_count, cpu_backend_context->max_num_threads());
 
   if (capped_thread_count == 1) {
-    MeanImpl(op_params, input_shape, input_data, input_zero_point, input_scale,
-             output_shape, output_data, output_zero_point, output_scale, 0,
-             output_depth);
+    MeanImpl(op_params, input_shape, input_data, multiplier, shift, bias,
+             output_shape, output_data, 0, output_depth);
   } else {
     // Instead parrallel for batch, we loop for the output_depth since batch
     // is typical 1.
@@ -1037,9 +1127,8 @@ inline void Mean(const tflite::MeanParams& op_params,
       // Try to distribute the tasks as even as possible.
       int depth_end = depth_start +
                       (output_depth - depth_start) / (capped_thread_count - i);
-      tasks.emplace_back(op_params, input_shape, input_data, input_zero_point,
-                         input_scale, output_shape, output_data,
-                         output_zero_point, output_scale, depth_start,
+      tasks.emplace_back(op_params, input_shape, input_data, multiplier, shift,
+                         bias, output_shape, output_data, depth_start,
                          depth_end);
       depth_start = depth_end;
     }
@@ -5465,71 +5554,6 @@ inline void TransposeConvV2(
   }
 }
 
-// TODO(renjieliu): Refactor this to merge with other
-// MultiplyByQuantizedMultipler.
-#ifdef USE_NEON
-inline int32x4x4_t MultiplyByQuantizedMultiplier4Rows(
-    int32x4x4_t input_val, int32 quantized_multiplier, int shift) {
-  using gemmlowp::RoundingDivideByPOT;
-  using gemmlowp::SaturatingRoundingDoublingHighMul;
-  const int left_shift = shift > 0 ? shift : 0;
-  const int right_shift = shift > 0 ? 0 : -shift;
-  int32x4x4_t result;
-  // The vector type support for SaturatingRoundingDoublingHighMulth in gemmlowp
-  // is limited to NEON.
-#ifdef GEMMLOWP_NEON
-  const int32x4_t left_shifted_one_dup = vdupq_n_s32(1 << left_shift);
-  result.val[0] =
-      RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
-                              vmulq_s32(input_val.val[0], left_shifted_one_dup),
-                              quantized_multiplier),
-                          right_shift);
-  result.val[1] =
-      RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
-                              vmulq_s32(input_val.val[1], left_shifted_one_dup),
-                              quantized_multiplier),
-                          right_shift);
-  result.val[2] =
-      RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
-                              vmulq_s32(input_val.val[2], left_shifted_one_dup),
-                              quantized_multiplier),
-                          right_shift);
-  result.val[3] =
-      RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
-                              vmulq_s32(input_val.val[3], left_shifted_one_dup),
-                              quantized_multiplier),
-                          right_shift);
-#else
-  for (int i = 0; i < 4; ++i) {
-    int32_t vals[4];
-    vals[0] = RoundingDivideByPOT(
-        SaturatingRoundingDoublingHighMul(
-            vgetq_lane_s32(input_val.val[i], 0) * (1 << left_shift),
-            quantized_multiplier),
-        right_shift);
-    vals[1] = RoundingDivideByPOT(
-        SaturatingRoundingDoublingHighMul(
-            vgetq_lane_s32(input_val.val[i], 1) * (1 << left_shift),
-            quantized_multiplier),
-        right_shift);
-    vals[2] = RoundingDivideByPOT(
-        SaturatingRoundingDoublingHighMul(
-            vgetq_lane_s32(input_val.val[i], 2) * (1 << left_shift),
-            quantized_multiplier),
-        right_shift);
-    vals[3] = RoundingDivideByPOT(
-        SaturatingRoundingDoublingHighMul(
-            vgetq_lane_s32(input_val.val[i], 3) * (1 << left_shift),
-            quantized_multiplier),
-        right_shift);
-
-    result.val[i] = vld1q_s32(reinterpret_cast(&vals));
-  }
-#endif
-  return result;
-}
-#endif
-
 inline void Quantize(int32_t multiplier, int32_t shift, int32_t total_size,
                      int32_t output_zp, int32_t* scratch, uint8_t* output) {
   gemmlowp::ScopedProfilingLabel label("Quantize/uint8");
diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h
index 502598c27f5..53b2049d74a 100644
--- a/tensorflow/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h
@@ -2622,7 +2622,7 @@ inline void Mean(const tflite::MeanParams& op_params,
           acc += input_data[Offset(input_shape, out_b, in_h, in_w, out_d)];
         }
       }
-      MultiplyByQuantizedMultiplier(acc, multiplier, shift);
+      acc = MultiplyByQuantizedMultiplier(acc, multiplier, shift);
       acc += bias;
       acc = std::min(std::max(acc, kMinValue), kMaxValue);
       output_data[Offset(output_shape, out_b, 0, 0, out_d)] =
diff --git a/tensorflow/lite/kernels/reduce.cc b/tensorflow/lite/kernels/reduce.cc
index 5685d4c4ff9..7c412334ab1 100644
--- a/tensorflow/lite/kernels/reduce.cc
+++ b/tensorflow/lite/kernels/reduce.cc
@@ -445,9 +445,25 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
     case kTfLiteUInt8: {
       // TODO(b/139102329): Handle all the cases in the combined reference
       // method.
-      if (op_context.input->params.zero_point ==
-              op_context.output->params.zero_point &&
-          op_context.input->params.scale == op_context.output->params.scale) {
+      tflite::MeanParams op_params;
+      op_params.axis_count = num_axis;
+      ResolveAxis(GetTensorData(op_context.axis), num_axis, &op_params);
+      if (op_context.params->keep_dims &&
+          NumDimensions(op_context.input) == 4 && op_params.axis_count == 2 &&
+          ((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
+           (op_params.axis[0] == 2 && op_params.axis[1] == 1))) {
+        reference_ops::Mean(op_params, GetTensorShape(op_context.input),
+                            GetTensorData(op_context.input),
+                            op_context.input->params.zero_point,
+                            op_context.input->params.scale,
+                            GetTensorShape(op_context.output),
+                            GetTensorData(op_context.output),
+                            op_context.output->params.zero_point,
+                            op_context.output->params.scale);
+      } else if (op_context.input->params.zero_point ==
+                     op_context.output->params.zero_point &&
+                 op_context.input->params.scale ==
+                     op_context.output->params.scale) {
         TF_LITE_ENSURE(
             context,
             reference_ops::Mean(
@@ -726,6 +742,7 @@ TfLiteRegistration* Register_MEAN() {
   return Register_MEAN_REF();
 #endif
 }
+
 TfLiteRegistration* Register_SUM() { return Register_SUM_REF(); }
 TfLiteRegistration* Register_REDUCE_PROD() {
   return Register_REDUCE_PROD_REF();
diff --git a/tensorflow/lite/kernels/reduce_test.cc b/tensorflow/lite/kernels/reduce_test.cc
index 87c178fc673..12b94e2019c 100644
--- a/tensorflow/lite/kernels/reduce_test.cc
+++ b/tensorflow/lite/kernels/reduce_test.cc
@@ -287,18 +287,21 @@ TEST(ConstFloatMeanOpTest, KeepDims4DMeanUInt8) {
 
 TEST(ConstFloatMeanOpTest, KeepDims4DMeanLargeDepthUInt8) {
   float kQuantizedTolerance = GetTolerance(-5.0, 5.0);
-  std::vector data = {0.1, 0.2, 0.3, 0.4, 0.2, 0.3, 0.4, 0.5, 0.1,
-                             0.1, 0.1, 0.1, 0.4, 0.2, 0.2, 0.2, 0.9, 0.9,
-                             0.9, 0.9, 0.2, 0.3, 0.7, 0.7, 0.1, 0.1, 0.3,
-                             0.3, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4};
-  MeanOpConstModel m({TensorType_UINT8, {1, 2, 2, 9}, -1.0, 1.0},
+  std::vector data = {
+      0.1, 0.2, 0.3, 0.4, 0.2, 0.3, 0.4, 0.5, 0.1, 0.1, 0.1, 0.1, 0.4, 0.2, 0.2,
+      0.2, 0.9, 0.9, 0.9, 0.9, 0.2, 0.3, 0.7, 0.7, 0.1, 0.1, 0.3, 0.3, 0.1, 0.2,
+      0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.2, 0.3, 0.4, 0.5, 0.1,
+      0.1, 0.1, 0.1, 0.4, 0.2, 0.2, 0.2, 0.9, 0.9, 0.9, 0.9, 0.2, 0.3, 0.7, 0.7,
+      0.1, 0.1, 0.3, 0.3, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4};
+  MeanOpConstModel m({TensorType_UINT8, {1, 2, 2, 18}, -1.0, 1.0},
                      {TensorType_UINT8, {2}, -1.0, 1.0}, {2}, {1, 2}, true);
   m.QuantizeAndPopulate(m.Input(), data);
   m.Invoke();
-  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 9}));
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 18}));
   EXPECT_THAT(m.GetDequantizedOutput(),
               ElementsAreArray(ArrayFloatNear(
-                  {0.35, 0.325, 0.2, 0.35, 0.375, 0.325, 0.225, 0.45, 0.425},
+                  {0.5, 0.55, 0.25, 0.35, 0.45, 0.5, 0.25, 0.3, 0.2, 0.2, 0.1,
+                   0.15, 0.35, 0.3, 0.15, 0.2, 0.6, 0.65},
                   kQuantizedTolerance)));
 }
 

From 22abc2772ab2d399bda8b122dae2fef99b62c29c Mon Sep 17 00:00:00 2001
From: Lei Zhang 
Date: Mon, 2 Dec 2019 14:21:42 -0800
Subject: [PATCH 175/279] [spirv] NFC: reorder sections in SPIRVBase.td

Put extensions and capabilities at the very beginning because
they will be referenced later by other definitions.

PiperOrigin-RevId: 283416972
Change-Id: I98255e327baf7de7f3debcc017551b0e07bd64f7
---
 .../include/mlir/Dialect/SPIRV/SPIRVBase.td   | 756 +++++++++---------
 .../mlir/utils/spirv/gen_spirv_dialect.py     |   5 +-
 2 files changed, 382 insertions(+), 379 deletions(-)

diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index 07cdd7ac790..e1897a9e295 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -54,236 +54,6 @@ def SPV_Dialect : Dialect {
   let cppNamespace = "spirv";
 }
 
-//===----------------------------------------------------------------------===//
-// SPIR-V opcode specification
-//===----------------------------------------------------------------------===//
-
-class SPV_OpCode {
-  // Name used as reference to retrieve the opcode
-  string opname = name;
-
-  // Opcode associated with the name
-  int opcode = val;
-}
-
-// Begin opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
-
-def SPV_OC_OpNop                    : I32EnumAttrCase<"OpNop", 0>;
-def SPV_OC_OpUndef                  : I32EnumAttrCase<"OpUndef", 1>;
-def SPV_OC_OpSourceContinued        : I32EnumAttrCase<"OpSourceContinued", 2>;
-def SPV_OC_OpSource                 : I32EnumAttrCase<"OpSource", 3>;
-def SPV_OC_OpSourceExtension        : I32EnumAttrCase<"OpSourceExtension", 4>;
-def SPV_OC_OpName                   : I32EnumAttrCase<"OpName", 5>;
-def SPV_OC_OpMemberName             : I32EnumAttrCase<"OpMemberName", 6>;
-def SPV_OC_OpString                 : I32EnumAttrCase<"OpString", 7>;
-def SPV_OC_OpExtension              : I32EnumAttrCase<"OpExtension", 10>;
-def SPV_OC_OpExtInstImport          : I32EnumAttrCase<"OpExtInstImport", 11>;
-def SPV_OC_OpExtInst                : I32EnumAttrCase<"OpExtInst", 12>;
-def SPV_OC_OpMemoryModel            : I32EnumAttrCase<"OpMemoryModel", 14>;
-def SPV_OC_OpEntryPoint             : I32EnumAttrCase<"OpEntryPoint", 15>;
-def SPV_OC_OpExecutionMode          : I32EnumAttrCase<"OpExecutionMode", 16>;
-def SPV_OC_OpCapability             : I32EnumAttrCase<"OpCapability", 17>;
-def SPV_OC_OpTypeVoid               : I32EnumAttrCase<"OpTypeVoid", 19>;
-def SPV_OC_OpTypeBool               : I32EnumAttrCase<"OpTypeBool", 20>;
-def SPV_OC_OpTypeInt                : I32EnumAttrCase<"OpTypeInt", 21>;
-def SPV_OC_OpTypeFloat              : I32EnumAttrCase<"OpTypeFloat", 22>;
-def SPV_OC_OpTypeVector             : I32EnumAttrCase<"OpTypeVector", 23>;
-def SPV_OC_OpTypeArray              : I32EnumAttrCase<"OpTypeArray", 28>;
-def SPV_OC_OpTypeRuntimeArray       : I32EnumAttrCase<"OpTypeRuntimeArray", 29>;
-def SPV_OC_OpTypeStruct             : I32EnumAttrCase<"OpTypeStruct", 30>;
-def SPV_OC_OpTypePointer            : I32EnumAttrCase<"OpTypePointer", 32>;
-def SPV_OC_OpTypeFunction           : I32EnumAttrCase<"OpTypeFunction", 33>;
-def SPV_OC_OpConstantTrue           : I32EnumAttrCase<"OpConstantTrue", 41>;
-def SPV_OC_OpConstantFalse          : I32EnumAttrCase<"OpConstantFalse", 42>;
-def SPV_OC_OpConstant               : I32EnumAttrCase<"OpConstant", 43>;
-def SPV_OC_OpConstantComposite      : I32EnumAttrCase<"OpConstantComposite", 44>;
-def SPV_OC_OpConstantNull           : I32EnumAttrCase<"OpConstantNull", 46>;
-def SPV_OC_OpSpecConstantTrue       : I32EnumAttrCase<"OpSpecConstantTrue", 48>;
-def SPV_OC_OpSpecConstantFalse      : I32EnumAttrCase<"OpSpecConstantFalse", 49>;
-def SPV_OC_OpSpecConstant           : I32EnumAttrCase<"OpSpecConstant", 50>;
-def SPV_OC_OpSpecConstantComposite  : I32EnumAttrCase<"OpSpecConstantComposite", 51>;
-def SPV_OC_OpFunction               : I32EnumAttrCase<"OpFunction", 54>;
-def SPV_OC_OpFunctionParameter      : I32EnumAttrCase<"OpFunctionParameter", 55>;
-def SPV_OC_OpFunctionEnd            : I32EnumAttrCase<"OpFunctionEnd", 56>;
-def SPV_OC_OpFunctionCall           : I32EnumAttrCase<"OpFunctionCall", 57>;
-def SPV_OC_OpVariable               : I32EnumAttrCase<"OpVariable", 59>;
-def SPV_OC_OpLoad                   : I32EnumAttrCase<"OpLoad", 61>;
-def SPV_OC_OpStore                  : I32EnumAttrCase<"OpStore", 62>;
-def SPV_OC_OpAccessChain            : I32EnumAttrCase<"OpAccessChain", 65>;
-def SPV_OC_OpDecorate               : I32EnumAttrCase<"OpDecorate", 71>;
-def SPV_OC_OpMemberDecorate         : I32EnumAttrCase<"OpMemberDecorate", 72>;
-def SPV_OC_OpCompositeExtract       : I32EnumAttrCase<"OpCompositeExtract", 81>;
-def SPV_OC_OpConvertFToU            : I32EnumAttrCase<"OpConvertFToU", 109>;
-def SPV_OC_OpConvertFToS            : I32EnumAttrCase<"OpConvertFToS", 110>;
-def SPV_OC_OpConvertSToF            : I32EnumAttrCase<"OpConvertSToF", 111>;
-def SPV_OC_OpConvertUToF            : I32EnumAttrCase<"OpConvertUToF", 112>;
-def SPV_OC_OpUConvert               : I32EnumAttrCase<"OpUConvert", 113>;
-def SPV_OC_OpSConvert               : I32EnumAttrCase<"OpSConvert", 114>;
-def SPV_OC_OpFConvert               : I32EnumAttrCase<"OpFConvert", 115>;
-def SPV_OC_OpBitcast                : I32EnumAttrCase<"OpBitcast", 124>;
-def SPV_OC_OpFNegate                : I32EnumAttrCase<"OpFNegate", 127>;
-def SPV_OC_OpIAdd                   : I32EnumAttrCase<"OpIAdd", 128>;
-def SPV_OC_OpFAdd                   : I32EnumAttrCase<"OpFAdd", 129>;
-def SPV_OC_OpISub                   : I32EnumAttrCase<"OpISub", 130>;
-def SPV_OC_OpFSub                   : I32EnumAttrCase<"OpFSub", 131>;
-def SPV_OC_OpIMul                   : I32EnumAttrCase<"OpIMul", 132>;
-def SPV_OC_OpFMul                   : I32EnumAttrCase<"OpFMul", 133>;
-def SPV_OC_OpUDiv                   : I32EnumAttrCase<"OpUDiv", 134>;
-def SPV_OC_OpSDiv                   : I32EnumAttrCase<"OpSDiv", 135>;
-def SPV_OC_OpFDiv                   : I32EnumAttrCase<"OpFDiv", 136>;
-def SPV_OC_OpUMod                   : I32EnumAttrCase<"OpUMod", 137>;
-def SPV_OC_OpSRem                   : I32EnumAttrCase<"OpSRem", 138>;
-def SPV_OC_OpSMod                   : I32EnumAttrCase<"OpSMod", 139>;
-def SPV_OC_OpFRem                   : I32EnumAttrCase<"OpFRem", 140>;
-def SPV_OC_OpFMod                   : I32EnumAttrCase<"OpFMod", 141>;
-def SPV_OC_OpLogicalEqual           : I32EnumAttrCase<"OpLogicalEqual", 164>;
-def SPV_OC_OpLogicalNotEqual        : I32EnumAttrCase<"OpLogicalNotEqual", 165>;
-def SPV_OC_OpLogicalOr              : I32EnumAttrCase<"OpLogicalOr", 166>;
-def SPV_OC_OpLogicalAnd             : I32EnumAttrCase<"OpLogicalAnd", 167>;
-def SPV_OC_OpLogicalNot             : I32EnumAttrCase<"OpLogicalNot", 168>;
-def SPV_OC_OpSelect                 : I32EnumAttrCase<"OpSelect", 169>;
-def SPV_OC_OpIEqual                 : I32EnumAttrCase<"OpIEqual", 170>;
-def SPV_OC_OpINotEqual              : I32EnumAttrCase<"OpINotEqual", 171>;
-def SPV_OC_OpUGreaterThan           : I32EnumAttrCase<"OpUGreaterThan", 172>;
-def SPV_OC_OpSGreaterThan           : I32EnumAttrCase<"OpSGreaterThan", 173>;
-def SPV_OC_OpUGreaterThanEqual      : I32EnumAttrCase<"OpUGreaterThanEqual", 174>;
-def SPV_OC_OpSGreaterThanEqual      : I32EnumAttrCase<"OpSGreaterThanEqual", 175>;
-def SPV_OC_OpULessThan              : I32EnumAttrCase<"OpULessThan", 176>;
-def SPV_OC_OpSLessThan              : I32EnumAttrCase<"OpSLessThan", 177>;
-def SPV_OC_OpULessThanEqual         : I32EnumAttrCase<"OpULessThanEqual", 178>;
-def SPV_OC_OpSLessThanEqual         : I32EnumAttrCase<"OpSLessThanEqual", 179>;
-def SPV_OC_OpFOrdEqual              : I32EnumAttrCase<"OpFOrdEqual", 180>;
-def SPV_OC_OpFUnordEqual            : I32EnumAttrCase<"OpFUnordEqual", 181>;
-def SPV_OC_OpFOrdNotEqual           : I32EnumAttrCase<"OpFOrdNotEqual", 182>;
-def SPV_OC_OpFUnordNotEqual         : I32EnumAttrCase<"OpFUnordNotEqual", 183>;
-def SPV_OC_OpFOrdLessThan           : I32EnumAttrCase<"OpFOrdLessThan", 184>;
-def SPV_OC_OpFUnordLessThan         : I32EnumAttrCase<"OpFUnordLessThan", 185>;
-def SPV_OC_OpFOrdGreaterThan        : I32EnumAttrCase<"OpFOrdGreaterThan", 186>;
-def SPV_OC_OpFUnordGreaterThan      : I32EnumAttrCase<"OpFUnordGreaterThan", 187>;
-def SPV_OC_OpFOrdLessThanEqual      : I32EnumAttrCase<"OpFOrdLessThanEqual", 188>;
-def SPV_OC_OpFUnordLessThanEqual    : I32EnumAttrCase<"OpFUnordLessThanEqual", 189>;
-def SPV_OC_OpFOrdGreaterThanEqual   : I32EnumAttrCase<"OpFOrdGreaterThanEqual", 190>;
-def SPV_OC_OpFUnordGreaterThanEqual : I32EnumAttrCase<"OpFUnordGreaterThanEqual", 191>;
-def SPV_OC_OpShiftRightLogical      : I32EnumAttrCase<"OpShiftRightLogical", 194>;
-def SPV_OC_OpShiftRightArithmetic   : I32EnumAttrCase<"OpShiftRightArithmetic", 195>;
-def SPV_OC_OpShiftLeftLogical       : I32EnumAttrCase<"OpShiftLeftLogical", 196>;
-def SPV_OC_OpBitwiseOr              : I32EnumAttrCase<"OpBitwiseOr", 197>;
-def SPV_OC_OpBitwiseXor             : I32EnumAttrCase<"OpBitwiseXor", 198>;
-def SPV_OC_OpBitwiseAnd             : I32EnumAttrCase<"OpBitwiseAnd", 199>;
-def SPV_OC_OpNot                    : I32EnumAttrCase<"OpNot", 200>;
-def SPV_OC_OpBitFieldInsert         : I32EnumAttrCase<"OpBitFieldInsert", 201>;
-def SPV_OC_OpBitFieldSExtract       : I32EnumAttrCase<"OpBitFieldSExtract", 202>;
-def SPV_OC_OpBitFieldUExtract       : I32EnumAttrCase<"OpBitFieldUExtract", 203>;
-def SPV_OC_OpBitReverse             : I32EnumAttrCase<"OpBitReverse", 204>;
-def SPV_OC_OpBitCount               : I32EnumAttrCase<"OpBitCount", 205>;
-def SPV_OC_OpControlBarrier         : I32EnumAttrCase<"OpControlBarrier", 224>;
-def SPV_OC_OpMemoryBarrier          : I32EnumAttrCase<"OpMemoryBarrier", 225>;
-def SPV_OC_OpPhi                    : I32EnumAttrCase<"OpPhi", 245>;
-def SPV_OC_OpLoopMerge              : I32EnumAttrCase<"OpLoopMerge", 246>;
-def SPV_OC_OpSelectionMerge         : I32EnumAttrCase<"OpSelectionMerge", 247>;
-def SPV_OC_OpLabel                  : I32EnumAttrCase<"OpLabel", 248>;
-def SPV_OC_OpBranch                 : I32EnumAttrCase<"OpBranch", 249>;
-def SPV_OC_OpBranchConditional      : I32EnumAttrCase<"OpBranchConditional", 250>;
-def SPV_OC_OpReturn                 : I32EnumAttrCase<"OpReturn", 253>;
-def SPV_OC_OpReturnValue            : I32EnumAttrCase<"OpReturnValue", 254>;
-def SPV_OC_OpUnreachable            : I32EnumAttrCase<"OpUnreachable", 255>;
-def SPV_OC_OpModuleProcessed        : I32EnumAttrCase<"OpModuleProcessed", 330>;
-
-def SPV_OpcodeAttr :
-    I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
-      SPV_OC_OpNop, SPV_OC_OpUndef, SPV_OC_OpSourceContinued, SPV_OC_OpSource,
-      SPV_OC_OpSourceExtension, SPV_OC_OpName, SPV_OC_OpMemberName, SPV_OC_OpString,
-      SPV_OC_OpExtension, SPV_OC_OpExtInstImport, SPV_OC_OpExtInst,
-      SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode,
-      SPV_OC_OpCapability, SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt,
-      SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeArray,
-      SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct, SPV_OC_OpTypePointer,
-      SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse,
-      SPV_OC_OpConstant, SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull,
-      SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant,
-      SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
-      SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad,
-      SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
-      SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpConvertFToU,
-      SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
-      SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast,
-      SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub,
-      SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv,
-      SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod,
-      SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr,
-      SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual,
-      SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
-      SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
-      SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
-      SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
-      SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan,
-      SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
-      SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
-      SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
-      SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,
-      SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr, SPV_OC_OpBitwiseXor,
-      SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, SPV_OC_OpBitFieldInsert,
-      SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, SPV_OC_OpBitReverse,
-      SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier,
-      SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel,
-      SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
-      SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpModuleProcessed
-      ]> {
-    let returnType = "::mlir::spirv::Opcode";
-    let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";
-    let cppNamespace = "::mlir::spirv";
-}
-
-// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
-
-//===----------------------------------------------------------------------===//
-// SPIR-V type definitions
-//===----------------------------------------------------------------------===//
-
-def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">;
-def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">;
-def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">;
-def SPV_IsStructType : CPred<"$_self.isa<::mlir::spirv::StructType>()">;
-
-// See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types
-// for the definition of the following types and type categories.
-
-def SPV_Void : TypeAlias;
-def SPV_Bool : IntOfWidths<[1]>;
-def SPV_Integer : IntOfWidths<[8, 16, 32, 64]>;
-def SPV_Float : FloatOfWidths<[16, 32, 64]>;
-def SPV_Float16or32 : FloatOfWidths<[16, 32]>;
-def SPV_Vector : VectorOfLengthAndType<[2, 3, 4],
-                                       [SPV_Bool, SPV_Integer, SPV_Float]>;
-// Component type check is done in the type parser for the following SPIR-V
-// dialect-specific types so we use "Any" here.
-def SPV_AnyPtr : Type;
-def SPV_AnyArray : Type;
-def SPV_AnyRTArray : Type;
-def SPV_AnyStruct : Type;
-
-def SPV_Numerical : AnyTypeOf<[SPV_Integer, SPV_Float]>;
-def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>;
-def SPV_Aggregate : AnyTypeOf<[SPV_AnyArray, SPV_AnyStruct]>;
-def SPV_Composite :
-    AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct]>;
-def SPV_Type : AnyTypeOf<[
-    SPV_Void, SPV_Bool, SPV_Integer, SPV_Float, SPV_Vector,
-    SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct
-  ]>;
-
-class SPV_ScalarOrVectorOf :
-    AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4], [type]>]>;
-
-def SPV_ScalarOrVector : AnyTypeOf<[SPV_Scalar, SPV_Vector]>;
-def SPV_ScalarOrVectorOrPtr : AnyTypeOf<[SPV_ScalarOrVector, SPV_AnyPtr]>;
-
-// TODO(antiagainst): Use a more appropriate way to model optional operands
-class SPV_Optional : Variadic;
-
-// TODO(ravishankarm): From 1.4, this should also include Composite type.
-def SPV_SelectType : AnyTypeOf<[SPV_Scalar, SPV_Vector, SPV_AnyPtr]>;
-
 //===----------------------------------------------------------------------===//
 // SPIR-V extension definitions
 //===----------------------------------------------------------------------===//
@@ -316,153 +86,6 @@ def SPV_ExtensionAttr :
 
 // Begin enum section. Generated from SPIR-V spec; DO NOT MODIFY!
 
-def SPV_AM_Logical                 : I32EnumAttrCase<"Logical", 0>;
-def SPV_AM_Physical32              : I32EnumAttrCase<"Physical32", 1>;
-def SPV_AM_Physical64              : I32EnumAttrCase<"Physical64", 2>;
-def SPV_AM_PhysicalStorageBuffer64 : I32EnumAttrCase<"PhysicalStorageBuffer64", 5348>;
-
-def SPV_AddressingModelAttr :
-    I32EnumAttr<"AddressingModel", "valid SPIR-V AddressingModel", [
-      SPV_AM_Logical, SPV_AM_Physical32, SPV_AM_Physical64,
-      SPV_AM_PhysicalStorageBuffer64
-    ]> {
-  let cppNamespace = "::mlir::spirv";
-}
-
-def SPV_BI_Position                    : I32EnumAttrCase<"Position", 0>;
-def SPV_BI_PointSize                   : I32EnumAttrCase<"PointSize", 1>;
-def SPV_BI_ClipDistance                : I32EnumAttrCase<"ClipDistance", 3>;
-def SPV_BI_CullDistance                : I32EnumAttrCase<"CullDistance", 4>;
-def SPV_BI_VertexId                    : I32EnumAttrCase<"VertexId", 5>;
-def SPV_BI_InstanceId                  : I32EnumAttrCase<"InstanceId", 6>;
-def SPV_BI_PrimitiveId                 : I32EnumAttrCase<"PrimitiveId", 7>;
-def SPV_BI_InvocationId                : I32EnumAttrCase<"InvocationId", 8>;
-def SPV_BI_Layer                       : I32EnumAttrCase<"Layer", 9>;
-def SPV_BI_ViewportIndex               : I32EnumAttrCase<"ViewportIndex", 10>;
-def SPV_BI_TessLevelOuter              : I32EnumAttrCase<"TessLevelOuter", 11>;
-def SPV_BI_TessLevelInner              : I32EnumAttrCase<"TessLevelInner", 12>;
-def SPV_BI_TessCoord                   : I32EnumAttrCase<"TessCoord", 13>;
-def SPV_BI_PatchVertices               : I32EnumAttrCase<"PatchVertices", 14>;
-def SPV_BI_FragCoord                   : I32EnumAttrCase<"FragCoord", 15>;
-def SPV_BI_PointCoord                  : I32EnumAttrCase<"PointCoord", 16>;
-def SPV_BI_FrontFacing                 : I32EnumAttrCase<"FrontFacing", 17>;
-def SPV_BI_SampleId                    : I32EnumAttrCase<"SampleId", 18>;
-def SPV_BI_SamplePosition              : I32EnumAttrCase<"SamplePosition", 19>;
-def SPV_BI_SampleMask                  : I32EnumAttrCase<"SampleMask", 20>;
-def SPV_BI_FragDepth                   : I32EnumAttrCase<"FragDepth", 22>;
-def SPV_BI_HelperInvocation            : I32EnumAttrCase<"HelperInvocation", 23>;
-def SPV_BI_NumWorkgroups               : I32EnumAttrCase<"NumWorkgroups", 24>;
-def SPV_BI_WorkgroupSize               : I32EnumAttrCase<"WorkgroupSize", 25>;
-def SPV_BI_WorkgroupId                 : I32EnumAttrCase<"WorkgroupId", 26>;
-def SPV_BI_LocalInvocationId           : I32EnumAttrCase<"LocalInvocationId", 27>;
-def SPV_BI_GlobalInvocationId          : I32EnumAttrCase<"GlobalInvocationId", 28>;
-def SPV_BI_LocalInvocationIndex        : I32EnumAttrCase<"LocalInvocationIndex", 29>;
-def SPV_BI_WorkDim                     : I32EnumAttrCase<"WorkDim", 30>;
-def SPV_BI_GlobalSize                  : I32EnumAttrCase<"GlobalSize", 31>;
-def SPV_BI_EnqueuedWorkgroupSize       : I32EnumAttrCase<"EnqueuedWorkgroupSize", 32>;
-def SPV_BI_GlobalOffset                : I32EnumAttrCase<"GlobalOffset", 33>;
-def SPV_BI_GlobalLinearId              : I32EnumAttrCase<"GlobalLinearId", 34>;
-def SPV_BI_SubgroupSize                : I32EnumAttrCase<"SubgroupSize", 36>;
-def SPV_BI_SubgroupMaxSize             : I32EnumAttrCase<"SubgroupMaxSize", 37>;
-def SPV_BI_NumSubgroups                : I32EnumAttrCase<"NumSubgroups", 38>;
-def SPV_BI_NumEnqueuedSubgroups        : I32EnumAttrCase<"NumEnqueuedSubgroups", 39>;
-def SPV_BI_SubgroupId                  : I32EnumAttrCase<"SubgroupId", 40>;
-def SPV_BI_SubgroupLocalInvocationId   : I32EnumAttrCase<"SubgroupLocalInvocationId", 41>;
-def SPV_BI_VertexIndex                 : I32EnumAttrCase<"VertexIndex", 42>;
-def SPV_BI_InstanceIndex               : I32EnumAttrCase<"InstanceIndex", 43>;
-def SPV_BI_SubgroupEqMask              : I32EnumAttrCase<"SubgroupEqMask", 4416>;
-def SPV_BI_SubgroupGeMask              : I32EnumAttrCase<"SubgroupGeMask", 4417>;
-def SPV_BI_SubgroupGtMask              : I32EnumAttrCase<"SubgroupGtMask", 4418>;
-def SPV_BI_SubgroupLeMask              : I32EnumAttrCase<"SubgroupLeMask", 4419>;
-def SPV_BI_SubgroupLtMask              : I32EnumAttrCase<"SubgroupLtMask", 4420>;
-def SPV_BI_BaseVertex                  : I32EnumAttrCase<"BaseVertex", 4424>;
-def SPV_BI_BaseInstance                : I32EnumAttrCase<"BaseInstance", 4425>;
-def SPV_BI_DrawIndex                   : I32EnumAttrCase<"DrawIndex", 4426>;
-def SPV_BI_DeviceIndex                 : I32EnumAttrCase<"DeviceIndex", 4438>;
-def SPV_BI_ViewIndex                   : I32EnumAttrCase<"ViewIndex", 4440>;
-def SPV_BI_BaryCoordNoPerspAMD         : I32EnumAttrCase<"BaryCoordNoPerspAMD", 4992>;
-def SPV_BI_BaryCoordNoPerspCentroidAMD : I32EnumAttrCase<"BaryCoordNoPerspCentroidAMD", 4993>;
-def SPV_BI_BaryCoordNoPerspSampleAMD   : I32EnumAttrCase<"BaryCoordNoPerspSampleAMD", 4994>;
-def SPV_BI_BaryCoordSmoothAMD          : I32EnumAttrCase<"BaryCoordSmoothAMD", 4995>;
-def SPV_BI_BaryCoordSmoothCentroidAMD  : I32EnumAttrCase<"BaryCoordSmoothCentroidAMD", 4996>;
-def SPV_BI_BaryCoordSmoothSampleAMD    : I32EnumAttrCase<"BaryCoordSmoothSampleAMD", 4997>;
-def SPV_BI_BaryCoordPullModelAMD       : I32EnumAttrCase<"BaryCoordPullModelAMD", 4998>;
-def SPV_BI_FragStencilRefEXT           : I32EnumAttrCase<"FragStencilRefEXT", 5014>;
-def SPV_BI_ViewportMaskNV              : I32EnumAttrCase<"ViewportMaskNV", 5253>;
-def SPV_BI_SecondaryPositionNV         : I32EnumAttrCase<"SecondaryPositionNV", 5257>;
-def SPV_BI_SecondaryViewportMaskNV     : I32EnumAttrCase<"SecondaryViewportMaskNV", 5258>;
-def SPV_BI_PositionPerViewNV           : I32EnumAttrCase<"PositionPerViewNV", 5261>;
-def SPV_BI_ViewportMaskPerViewNV       : I32EnumAttrCase<"ViewportMaskPerViewNV", 5262>;
-def SPV_BI_FullyCoveredEXT             : I32EnumAttrCase<"FullyCoveredEXT", 5264>;
-def SPV_BI_TaskCountNV                 : I32EnumAttrCase<"TaskCountNV", 5274>;
-def SPV_BI_PrimitiveCountNV            : I32EnumAttrCase<"PrimitiveCountNV", 5275>;
-def SPV_BI_PrimitiveIndicesNV          : I32EnumAttrCase<"PrimitiveIndicesNV", 5276>;
-def SPV_BI_ClipDistancePerViewNV       : I32EnumAttrCase<"ClipDistancePerViewNV", 5277>;
-def SPV_BI_CullDistancePerViewNV       : I32EnumAttrCase<"CullDistancePerViewNV", 5278>;
-def SPV_BI_LayerPerViewNV              : I32EnumAttrCase<"LayerPerViewNV", 5279>;
-def SPV_BI_MeshViewCountNV             : I32EnumAttrCase<"MeshViewCountNV", 5280>;
-def SPV_BI_MeshViewIndicesNV           : I32EnumAttrCase<"MeshViewIndicesNV", 5281>;
-def SPV_BI_BaryCoordNV                 : I32EnumAttrCase<"BaryCoordNV", 5286>;
-def SPV_BI_BaryCoordNoPerspNV          : I32EnumAttrCase<"BaryCoordNoPerspNV", 5287>;
-def SPV_BI_FragSizeEXT                 : I32EnumAttrCase<"FragSizeEXT", 5292>;
-def SPV_BI_FragInvocationCountEXT      : I32EnumAttrCase<"FragInvocationCountEXT", 5293>;
-def SPV_BI_LaunchIdNV                  : I32EnumAttrCase<"LaunchIdNV", 5319>;
-def SPV_BI_LaunchSizeNV                : I32EnumAttrCase<"LaunchSizeNV", 5320>;
-def SPV_BI_WorldRayOriginNV            : I32EnumAttrCase<"WorldRayOriginNV", 5321>;
-def SPV_BI_WorldRayDirectionNV         : I32EnumAttrCase<"WorldRayDirectionNV", 5322>;
-def SPV_BI_ObjectRayOriginNV           : I32EnumAttrCase<"ObjectRayOriginNV", 5323>;
-def SPV_BI_ObjectRayDirectionNV        : I32EnumAttrCase<"ObjectRayDirectionNV", 5324>;
-def SPV_BI_RayTminNV                   : I32EnumAttrCase<"RayTminNV", 5325>;
-def SPV_BI_RayTmaxNV                   : I32EnumAttrCase<"RayTmaxNV", 5326>;
-def SPV_BI_InstanceCustomIndexNV       : I32EnumAttrCase<"InstanceCustomIndexNV", 5327>;
-def SPV_BI_ObjectToWorldNV             : I32EnumAttrCase<"ObjectToWorldNV", 5330>;
-def SPV_BI_WorldToObjectNV             : I32EnumAttrCase<"WorldToObjectNV", 5331>;
-def SPV_BI_HitTNV                      : I32EnumAttrCase<"HitTNV", 5332>;
-def SPV_BI_HitKindNV                   : I32EnumAttrCase<"HitKindNV", 5333>;
-def SPV_BI_IncomingRayFlagsNV          : I32EnumAttrCase<"IncomingRayFlagsNV", 5351>;
-def SPV_BI_WarpsPerSMNV                : I32EnumAttrCase<"WarpsPerSMNV", 5374>;
-def SPV_BI_SMCountNV                   : I32EnumAttrCase<"SMCountNV", 5375>;
-def SPV_BI_WarpIDNV                    : I32EnumAttrCase<"WarpIDNV", 5376>;
-def SPV_BI_SMIDNV                      : I32EnumAttrCase<"SMIDNV", 5377>;
-
-def SPV_BuiltInAttr :
-    I32EnumAttr<"BuiltIn", "valid SPIR-V BuiltIn", [
-      SPV_BI_Position, SPV_BI_PointSize, SPV_BI_ClipDistance, SPV_BI_CullDistance,
-      SPV_BI_VertexId, SPV_BI_InstanceId, SPV_BI_PrimitiveId, SPV_BI_InvocationId,
-      SPV_BI_Layer, SPV_BI_ViewportIndex, SPV_BI_TessLevelOuter,
-      SPV_BI_TessLevelInner, SPV_BI_TessCoord, SPV_BI_PatchVertices,
-      SPV_BI_FragCoord, SPV_BI_PointCoord, SPV_BI_FrontFacing, SPV_BI_SampleId,
-      SPV_BI_SamplePosition, SPV_BI_SampleMask, SPV_BI_FragDepth,
-      SPV_BI_HelperInvocation, SPV_BI_NumWorkgroups, SPV_BI_WorkgroupSize,
-      SPV_BI_WorkgroupId, SPV_BI_LocalInvocationId, SPV_BI_GlobalInvocationId,
-      SPV_BI_LocalInvocationIndex, SPV_BI_WorkDim, SPV_BI_GlobalSize,
-      SPV_BI_EnqueuedWorkgroupSize, SPV_BI_GlobalOffset, SPV_BI_GlobalLinearId,
-      SPV_BI_SubgroupSize, SPV_BI_SubgroupMaxSize, SPV_BI_NumSubgroups,
-      SPV_BI_NumEnqueuedSubgroups, SPV_BI_SubgroupId,
-      SPV_BI_SubgroupLocalInvocationId, SPV_BI_VertexIndex, SPV_BI_InstanceIndex,
-      SPV_BI_SubgroupEqMask, SPV_BI_SubgroupGeMask, SPV_BI_SubgroupGtMask,
-      SPV_BI_SubgroupLeMask, SPV_BI_SubgroupLtMask, SPV_BI_BaseVertex,
-      SPV_BI_BaseInstance, SPV_BI_DrawIndex, SPV_BI_DeviceIndex, SPV_BI_ViewIndex,
-      SPV_BI_BaryCoordNoPerspAMD, SPV_BI_BaryCoordNoPerspCentroidAMD,
-      SPV_BI_BaryCoordNoPerspSampleAMD, SPV_BI_BaryCoordSmoothAMD,
-      SPV_BI_BaryCoordSmoothCentroidAMD, SPV_BI_BaryCoordSmoothSampleAMD,
-      SPV_BI_BaryCoordPullModelAMD, SPV_BI_FragStencilRefEXT, SPV_BI_ViewportMaskNV,
-      SPV_BI_SecondaryPositionNV, SPV_BI_SecondaryViewportMaskNV,
-      SPV_BI_PositionPerViewNV, SPV_BI_ViewportMaskPerViewNV, SPV_BI_FullyCoveredEXT,
-      SPV_BI_TaskCountNV, SPV_BI_PrimitiveCountNV, SPV_BI_PrimitiveIndicesNV,
-      SPV_BI_ClipDistancePerViewNV, SPV_BI_CullDistancePerViewNV,
-      SPV_BI_LayerPerViewNV, SPV_BI_MeshViewCountNV, SPV_BI_MeshViewIndicesNV,
-      SPV_BI_BaryCoordNV, SPV_BI_BaryCoordNoPerspNV, SPV_BI_FragSizeEXT,
-      SPV_BI_FragInvocationCountEXT, SPV_BI_LaunchIdNV, SPV_BI_LaunchSizeNV,
-      SPV_BI_WorldRayOriginNV, SPV_BI_WorldRayDirectionNV, SPV_BI_ObjectRayOriginNV,
-      SPV_BI_ObjectRayDirectionNV, SPV_BI_RayTminNV, SPV_BI_RayTmaxNV,
-      SPV_BI_InstanceCustomIndexNV, SPV_BI_ObjectToWorldNV, SPV_BI_WorldToObjectNV,
-      SPV_BI_HitTNV, SPV_BI_HitKindNV, SPV_BI_IncomingRayFlagsNV,
-      SPV_BI_WarpsPerSMNV, SPV_BI_SMCountNV, SPV_BI_WarpIDNV, SPV_BI_SMIDNV
-    ]> {
-  let cppNamespace = "::mlir::spirv";
-}
-
 def SPV_C_Matrix                                    : I32EnumAttrCase<"Matrix", 0>;
 def SPV_C_Shader                                    : I32EnumAttrCase<"Shader", 1>;
 def SPV_C_Geometry                                  : I32EnumAttrCase<"Geometry", 2>;
@@ -671,6 +294,153 @@ def SPV_CapabilityAttr :
   let cppNamespace = "::mlir::spirv";
 }
 
+def SPV_AM_Logical                 : I32EnumAttrCase<"Logical", 0>;
+def SPV_AM_Physical32              : I32EnumAttrCase<"Physical32", 1>;
+def SPV_AM_Physical64              : I32EnumAttrCase<"Physical64", 2>;
+def SPV_AM_PhysicalStorageBuffer64 : I32EnumAttrCase<"PhysicalStorageBuffer64", 5348>;
+
+def SPV_AddressingModelAttr :
+    I32EnumAttr<"AddressingModel", "valid SPIR-V AddressingModel", [
+      SPV_AM_Logical, SPV_AM_Physical32, SPV_AM_Physical64,
+      SPV_AM_PhysicalStorageBuffer64
+    ]> {
+  let cppNamespace = "::mlir::spirv";
+}
+
+def SPV_BI_Position                    : I32EnumAttrCase<"Position", 0>;
+def SPV_BI_PointSize                   : I32EnumAttrCase<"PointSize", 1>;
+def SPV_BI_ClipDistance                : I32EnumAttrCase<"ClipDistance", 3>;
+def SPV_BI_CullDistance                : I32EnumAttrCase<"CullDistance", 4>;
+def SPV_BI_VertexId                    : I32EnumAttrCase<"VertexId", 5>;
+def SPV_BI_InstanceId                  : I32EnumAttrCase<"InstanceId", 6>;
+def SPV_BI_PrimitiveId                 : I32EnumAttrCase<"PrimitiveId", 7>;
+def SPV_BI_InvocationId                : I32EnumAttrCase<"InvocationId", 8>;
+def SPV_BI_Layer                       : I32EnumAttrCase<"Layer", 9>;
+def SPV_BI_ViewportIndex               : I32EnumAttrCase<"ViewportIndex", 10>;
+def SPV_BI_TessLevelOuter              : I32EnumAttrCase<"TessLevelOuter", 11>;
+def SPV_BI_TessLevelInner              : I32EnumAttrCase<"TessLevelInner", 12>;
+def SPV_BI_TessCoord                   : I32EnumAttrCase<"TessCoord", 13>;
+def SPV_BI_PatchVertices               : I32EnumAttrCase<"PatchVertices", 14>;
+def SPV_BI_FragCoord                   : I32EnumAttrCase<"FragCoord", 15>;
+def SPV_BI_PointCoord                  : I32EnumAttrCase<"PointCoord", 16>;
+def SPV_BI_FrontFacing                 : I32EnumAttrCase<"FrontFacing", 17>;
+def SPV_BI_SampleId                    : I32EnumAttrCase<"SampleId", 18>;
+def SPV_BI_SamplePosition              : I32EnumAttrCase<"SamplePosition", 19>;
+def SPV_BI_SampleMask                  : I32EnumAttrCase<"SampleMask", 20>;
+def SPV_BI_FragDepth                   : I32EnumAttrCase<"FragDepth", 22>;
+def SPV_BI_HelperInvocation            : I32EnumAttrCase<"HelperInvocation", 23>;
+def SPV_BI_NumWorkgroups               : I32EnumAttrCase<"NumWorkgroups", 24>;
+def SPV_BI_WorkgroupSize               : I32EnumAttrCase<"WorkgroupSize", 25>;
+def SPV_BI_WorkgroupId                 : I32EnumAttrCase<"WorkgroupId", 26>;
+def SPV_BI_LocalInvocationId           : I32EnumAttrCase<"LocalInvocationId", 27>;
+def SPV_BI_GlobalInvocationId          : I32EnumAttrCase<"GlobalInvocationId", 28>;
+def SPV_BI_LocalInvocationIndex        : I32EnumAttrCase<"LocalInvocationIndex", 29>;
+def SPV_BI_WorkDim                     : I32EnumAttrCase<"WorkDim", 30>;
+def SPV_BI_GlobalSize                  : I32EnumAttrCase<"GlobalSize", 31>;
+def SPV_BI_EnqueuedWorkgroupSize       : I32EnumAttrCase<"EnqueuedWorkgroupSize", 32>;
+def SPV_BI_GlobalOffset                : I32EnumAttrCase<"GlobalOffset", 33>;
+def SPV_BI_GlobalLinearId              : I32EnumAttrCase<"GlobalLinearId", 34>;
+def SPV_BI_SubgroupSize                : I32EnumAttrCase<"SubgroupSize", 36>;
+def SPV_BI_SubgroupMaxSize             : I32EnumAttrCase<"SubgroupMaxSize", 37>;
+def SPV_BI_NumSubgroups                : I32EnumAttrCase<"NumSubgroups", 38>;
+def SPV_BI_NumEnqueuedSubgroups        : I32EnumAttrCase<"NumEnqueuedSubgroups", 39>;
+def SPV_BI_SubgroupId                  : I32EnumAttrCase<"SubgroupId", 40>;
+def SPV_BI_SubgroupLocalInvocationId   : I32EnumAttrCase<"SubgroupLocalInvocationId", 41>;
+def SPV_BI_VertexIndex                 : I32EnumAttrCase<"VertexIndex", 42>;
+def SPV_BI_InstanceIndex               : I32EnumAttrCase<"InstanceIndex", 43>;
+def SPV_BI_SubgroupEqMask              : I32EnumAttrCase<"SubgroupEqMask", 4416>;
+def SPV_BI_SubgroupGeMask              : I32EnumAttrCase<"SubgroupGeMask", 4417>;
+def SPV_BI_SubgroupGtMask              : I32EnumAttrCase<"SubgroupGtMask", 4418>;
+def SPV_BI_SubgroupLeMask              : I32EnumAttrCase<"SubgroupLeMask", 4419>;
+def SPV_BI_SubgroupLtMask              : I32EnumAttrCase<"SubgroupLtMask", 4420>;
+def SPV_BI_BaseVertex                  : I32EnumAttrCase<"BaseVertex", 4424>;
+def SPV_BI_BaseInstance                : I32EnumAttrCase<"BaseInstance", 4425>;
+def SPV_BI_DrawIndex                   : I32EnumAttrCase<"DrawIndex", 4426>;
+def SPV_BI_DeviceIndex                 : I32EnumAttrCase<"DeviceIndex", 4438>;
+def SPV_BI_ViewIndex                   : I32EnumAttrCase<"ViewIndex", 4440>;
+def SPV_BI_BaryCoordNoPerspAMD         : I32EnumAttrCase<"BaryCoordNoPerspAMD", 4992>;
+def SPV_BI_BaryCoordNoPerspCentroidAMD : I32EnumAttrCase<"BaryCoordNoPerspCentroidAMD", 4993>;
+def SPV_BI_BaryCoordNoPerspSampleAMD   : I32EnumAttrCase<"BaryCoordNoPerspSampleAMD", 4994>;
+def SPV_BI_BaryCoordSmoothAMD          : I32EnumAttrCase<"BaryCoordSmoothAMD", 4995>;
+def SPV_BI_BaryCoordSmoothCentroidAMD  : I32EnumAttrCase<"BaryCoordSmoothCentroidAMD", 4996>;
+def SPV_BI_BaryCoordSmoothSampleAMD    : I32EnumAttrCase<"BaryCoordSmoothSampleAMD", 4997>;
+def SPV_BI_BaryCoordPullModelAMD       : I32EnumAttrCase<"BaryCoordPullModelAMD", 4998>;
+def SPV_BI_FragStencilRefEXT           : I32EnumAttrCase<"FragStencilRefEXT", 5014>;
+def SPV_BI_ViewportMaskNV              : I32EnumAttrCase<"ViewportMaskNV", 5253>;
+def SPV_BI_SecondaryPositionNV         : I32EnumAttrCase<"SecondaryPositionNV", 5257>;
+def SPV_BI_SecondaryViewportMaskNV     : I32EnumAttrCase<"SecondaryViewportMaskNV", 5258>;
+def SPV_BI_PositionPerViewNV           : I32EnumAttrCase<"PositionPerViewNV", 5261>;
+def SPV_BI_ViewportMaskPerViewNV       : I32EnumAttrCase<"ViewportMaskPerViewNV", 5262>;
+def SPV_BI_FullyCoveredEXT             : I32EnumAttrCase<"FullyCoveredEXT", 5264>;
+def SPV_BI_TaskCountNV                 : I32EnumAttrCase<"TaskCountNV", 5274>;
+def SPV_BI_PrimitiveCountNV            : I32EnumAttrCase<"PrimitiveCountNV", 5275>;
+def SPV_BI_PrimitiveIndicesNV          : I32EnumAttrCase<"PrimitiveIndicesNV", 5276>;
+def SPV_BI_ClipDistancePerViewNV       : I32EnumAttrCase<"ClipDistancePerViewNV", 5277>;
+def SPV_BI_CullDistancePerViewNV       : I32EnumAttrCase<"CullDistancePerViewNV", 5278>;
+def SPV_BI_LayerPerViewNV              : I32EnumAttrCase<"LayerPerViewNV", 5279>;
+def SPV_BI_MeshViewCountNV             : I32EnumAttrCase<"MeshViewCountNV", 5280>;
+def SPV_BI_MeshViewIndicesNV           : I32EnumAttrCase<"MeshViewIndicesNV", 5281>;
+def SPV_BI_BaryCoordNV                 : I32EnumAttrCase<"BaryCoordNV", 5286>;
+def SPV_BI_BaryCoordNoPerspNV          : I32EnumAttrCase<"BaryCoordNoPerspNV", 5287>;
+def SPV_BI_FragSizeEXT                 : I32EnumAttrCase<"FragSizeEXT", 5292>;
+def SPV_BI_FragInvocationCountEXT      : I32EnumAttrCase<"FragInvocationCountEXT", 5293>;
+def SPV_BI_LaunchIdNV                  : I32EnumAttrCase<"LaunchIdNV", 5319>;
+def SPV_BI_LaunchSizeNV                : I32EnumAttrCase<"LaunchSizeNV", 5320>;
+def SPV_BI_WorldRayOriginNV            : I32EnumAttrCase<"WorldRayOriginNV", 5321>;
+def SPV_BI_WorldRayDirectionNV         : I32EnumAttrCase<"WorldRayDirectionNV", 5322>;
+def SPV_BI_ObjectRayOriginNV           : I32EnumAttrCase<"ObjectRayOriginNV", 5323>;
+def SPV_BI_ObjectRayDirectionNV        : I32EnumAttrCase<"ObjectRayDirectionNV", 5324>;
+def SPV_BI_RayTminNV                   : I32EnumAttrCase<"RayTminNV", 5325>;
+def SPV_BI_RayTmaxNV                   : I32EnumAttrCase<"RayTmaxNV", 5326>;
+def SPV_BI_InstanceCustomIndexNV       : I32EnumAttrCase<"InstanceCustomIndexNV", 5327>;
+def SPV_BI_ObjectToWorldNV             : I32EnumAttrCase<"ObjectToWorldNV", 5330>;
+def SPV_BI_WorldToObjectNV             : I32EnumAttrCase<"WorldToObjectNV", 5331>;
+def SPV_BI_HitTNV                      : I32EnumAttrCase<"HitTNV", 5332>;
+def SPV_BI_HitKindNV                   : I32EnumAttrCase<"HitKindNV", 5333>;
+def SPV_BI_IncomingRayFlagsNV          : I32EnumAttrCase<"IncomingRayFlagsNV", 5351>;
+def SPV_BI_WarpsPerSMNV                : I32EnumAttrCase<"WarpsPerSMNV", 5374>;
+def SPV_BI_SMCountNV                   : I32EnumAttrCase<"SMCountNV", 5375>;
+def SPV_BI_WarpIDNV                    : I32EnumAttrCase<"WarpIDNV", 5376>;
+def SPV_BI_SMIDNV                      : I32EnumAttrCase<"SMIDNV", 5377>;
+
+def SPV_BuiltInAttr :
+    I32EnumAttr<"BuiltIn", "valid SPIR-V BuiltIn", [
+      SPV_BI_Position, SPV_BI_PointSize, SPV_BI_ClipDistance, SPV_BI_CullDistance,
+      SPV_BI_VertexId, SPV_BI_InstanceId, SPV_BI_PrimitiveId, SPV_BI_InvocationId,
+      SPV_BI_Layer, SPV_BI_ViewportIndex, SPV_BI_TessLevelOuter,
+      SPV_BI_TessLevelInner, SPV_BI_TessCoord, SPV_BI_PatchVertices,
+      SPV_BI_FragCoord, SPV_BI_PointCoord, SPV_BI_FrontFacing, SPV_BI_SampleId,
+      SPV_BI_SamplePosition, SPV_BI_SampleMask, SPV_BI_FragDepth,
+      SPV_BI_HelperInvocation, SPV_BI_NumWorkgroups, SPV_BI_WorkgroupSize,
+      SPV_BI_WorkgroupId, SPV_BI_LocalInvocationId, SPV_BI_GlobalInvocationId,
+      SPV_BI_LocalInvocationIndex, SPV_BI_WorkDim, SPV_BI_GlobalSize,
+      SPV_BI_EnqueuedWorkgroupSize, SPV_BI_GlobalOffset, SPV_BI_GlobalLinearId,
+      SPV_BI_SubgroupSize, SPV_BI_SubgroupMaxSize, SPV_BI_NumSubgroups,
+      SPV_BI_NumEnqueuedSubgroups, SPV_BI_SubgroupId,
+      SPV_BI_SubgroupLocalInvocationId, SPV_BI_VertexIndex, SPV_BI_InstanceIndex,
+      SPV_BI_SubgroupEqMask, SPV_BI_SubgroupGeMask, SPV_BI_SubgroupGtMask,
+      SPV_BI_SubgroupLeMask, SPV_BI_SubgroupLtMask, SPV_BI_BaseVertex,
+      SPV_BI_BaseInstance, SPV_BI_DrawIndex, SPV_BI_DeviceIndex, SPV_BI_ViewIndex,
+      SPV_BI_BaryCoordNoPerspAMD, SPV_BI_BaryCoordNoPerspCentroidAMD,
+      SPV_BI_BaryCoordNoPerspSampleAMD, SPV_BI_BaryCoordSmoothAMD,
+      SPV_BI_BaryCoordSmoothCentroidAMD, SPV_BI_BaryCoordSmoothSampleAMD,
+      SPV_BI_BaryCoordPullModelAMD, SPV_BI_FragStencilRefEXT, SPV_BI_ViewportMaskNV,
+      SPV_BI_SecondaryPositionNV, SPV_BI_SecondaryViewportMaskNV,
+      SPV_BI_PositionPerViewNV, SPV_BI_ViewportMaskPerViewNV, SPV_BI_FullyCoveredEXT,
+      SPV_BI_TaskCountNV, SPV_BI_PrimitiveCountNV, SPV_BI_PrimitiveIndicesNV,
+      SPV_BI_ClipDistancePerViewNV, SPV_BI_CullDistancePerViewNV,
+      SPV_BI_LayerPerViewNV, SPV_BI_MeshViewCountNV, SPV_BI_MeshViewIndicesNV,
+      SPV_BI_BaryCoordNV, SPV_BI_BaryCoordNoPerspNV, SPV_BI_FragSizeEXT,
+      SPV_BI_FragInvocationCountEXT, SPV_BI_LaunchIdNV, SPV_BI_LaunchSizeNV,
+      SPV_BI_WorldRayOriginNV, SPV_BI_WorldRayDirectionNV, SPV_BI_ObjectRayOriginNV,
+      SPV_BI_ObjectRayDirectionNV, SPV_BI_RayTminNV, SPV_BI_RayTmaxNV,
+      SPV_BI_InstanceCustomIndexNV, SPV_BI_ObjectToWorldNV, SPV_BI_WorldToObjectNV,
+      SPV_BI_HitTNV, SPV_BI_HitKindNV, SPV_BI_IncomingRayFlagsNV,
+      SPV_BI_WarpsPerSMNV, SPV_BI_SMCountNV, SPV_BI_WarpIDNV, SPV_BI_SMIDNV
+    ]> {
+  let cppNamespace = "::mlir::spirv";
+}
+
 def SPV_D_RelaxedPrecision            : I32EnumAttrCase<"RelaxedPrecision", 0>;
 def SPV_D_SpecId                      : I32EnumAttrCase<"SpecId", 1>;
 def SPV_D_Block                       : I32EnumAttrCase<"Block", 2>;
@@ -1101,7 +871,7 @@ def SPV_StorageClassAttr :
 
 // End enum section. Generated from SPIR-V spec; DO NOT MODIFY!
 
-// Enums added manually that are not part of SPIRV spec
+// Enums added manually that are not part of SPIR-V spec
 
 def SPV_IDI_NoDepth      : I32EnumAttrCase<"NoDepth", 0>;
 def SPV_IDI_IsDepth      : I32EnumAttrCase<"IsDepth", 1>;
@@ -1141,6 +911,54 @@ def SPV_SamplerUseAttr:
   let cppNamespace = "::mlir::spirv";
 }
 
+//===----------------------------------------------------------------------===//
+// SPIR-V type definitions
+//===----------------------------------------------------------------------===//
+
+def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">;
+def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">;
+def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">;
+def SPV_IsStructType : CPred<"$_self.isa<::mlir::spirv::StructType>()">;
+
+// See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types
+// for the definition of the following types and type categories.
+
+def SPV_Void : TypeAlias;
+def SPV_Bool : IntOfWidths<[1]>;
+def SPV_Integer : IntOfWidths<[8, 16, 32, 64]>;
+def SPV_Float : FloatOfWidths<[16, 32, 64]>;
+def SPV_Float16or32 : FloatOfWidths<[16, 32]>;
+def SPV_Vector : VectorOfLengthAndType<[2, 3, 4],
+                                       [SPV_Bool, SPV_Integer, SPV_Float]>;
+// Component type check is done in the type parser for the following SPIR-V
+// dialect-specific types so we use "Any" here.
+def SPV_AnyPtr : Type;
+def SPV_AnyArray : Type;
+def SPV_AnyRTArray : Type;
+def SPV_AnyStruct : Type;
+
+def SPV_Numerical : AnyTypeOf<[SPV_Integer, SPV_Float]>;
+def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>;
+def SPV_Aggregate : AnyTypeOf<[SPV_AnyArray, SPV_AnyStruct]>;
+def SPV_Composite :
+    AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct]>;
+def SPV_Type : AnyTypeOf<[
+    SPV_Void, SPV_Bool, SPV_Integer, SPV_Float, SPV_Vector,
+    SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct
+  ]>;
+
+class SPV_ScalarOrVectorOf :
+    AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4], [type]>]>;
+
+def SPV_ScalarOrVector : AnyTypeOf<[SPV_Scalar, SPV_Vector]>;
+def SPV_ScalarOrVectorOrPtr : AnyTypeOf<[SPV_ScalarOrVector, SPV_AnyPtr]>;
+
+// TODO(antiagainst): Use a more appropriate way to model optional operands
+class SPV_Optional : Variadic;
+
+// TODO(ravishankarm): From 1.4, this should also include Composite type.
+def SPV_SelectType : AnyTypeOf<[SPV_Scalar, SPV_Vector, SPV_AnyPtr]>;
+
 //===----------------------------------------------------------------------===//
 // SPIR-V OpTrait definitions
 //===----------------------------------------------------------------------===//
@@ -1155,6 +973,188 @@ def InModuleScope : PredOpTrait<
   "op must appear in a 'spv.module' block",
   CPred<"llvm::isa_and_nonnull($_op.getParentOp())">>;
 
+//===----------------------------------------------------------------------===//
+// SPIR-V opcode specification
+//===----------------------------------------------------------------------===//
+
+class SPV_OpCode {
+  // Name used as reference to retrieve the opcode
+  string opname = name;
+
+  // Opcode associated with the name
+  int opcode = val;
+}
+
+// Begin opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
+
+def SPV_OC_OpNop                    : I32EnumAttrCase<"OpNop", 0>;
+def SPV_OC_OpUndef                  : I32EnumAttrCase<"OpUndef", 1>;
+def SPV_OC_OpSourceContinued        : I32EnumAttrCase<"OpSourceContinued", 2>;
+def SPV_OC_OpSource                 : I32EnumAttrCase<"OpSource", 3>;
+def SPV_OC_OpSourceExtension        : I32EnumAttrCase<"OpSourceExtension", 4>;
+def SPV_OC_OpName                   : I32EnumAttrCase<"OpName", 5>;
+def SPV_OC_OpMemberName             : I32EnumAttrCase<"OpMemberName", 6>;
+def SPV_OC_OpString                 : I32EnumAttrCase<"OpString", 7>;
+def SPV_OC_OpExtension              : I32EnumAttrCase<"OpExtension", 10>;
+def SPV_OC_OpExtInstImport          : I32EnumAttrCase<"OpExtInstImport", 11>;
+def SPV_OC_OpExtInst                : I32EnumAttrCase<"OpExtInst", 12>;
+def SPV_OC_OpMemoryModel            : I32EnumAttrCase<"OpMemoryModel", 14>;
+def SPV_OC_OpEntryPoint             : I32EnumAttrCase<"OpEntryPoint", 15>;
+def SPV_OC_OpExecutionMode          : I32EnumAttrCase<"OpExecutionMode", 16>;
+def SPV_OC_OpCapability             : I32EnumAttrCase<"OpCapability", 17>;
+def SPV_OC_OpTypeVoid               : I32EnumAttrCase<"OpTypeVoid", 19>;
+def SPV_OC_OpTypeBool               : I32EnumAttrCase<"OpTypeBool", 20>;
+def SPV_OC_OpTypeInt                : I32EnumAttrCase<"OpTypeInt", 21>;
+def SPV_OC_OpTypeFloat              : I32EnumAttrCase<"OpTypeFloat", 22>;
+def SPV_OC_OpTypeVector             : I32EnumAttrCase<"OpTypeVector", 23>;
+def SPV_OC_OpTypeArray              : I32EnumAttrCase<"OpTypeArray", 28>;
+def SPV_OC_OpTypeRuntimeArray       : I32EnumAttrCase<"OpTypeRuntimeArray", 29>;
+def SPV_OC_OpTypeStruct             : I32EnumAttrCase<"OpTypeStruct", 30>;
+def SPV_OC_OpTypePointer            : I32EnumAttrCase<"OpTypePointer", 32>;
+def SPV_OC_OpTypeFunction           : I32EnumAttrCase<"OpTypeFunction", 33>;
+def SPV_OC_OpConstantTrue           : I32EnumAttrCase<"OpConstantTrue", 41>;
+def SPV_OC_OpConstantFalse          : I32EnumAttrCase<"OpConstantFalse", 42>;
+def SPV_OC_OpConstant               : I32EnumAttrCase<"OpConstant", 43>;
+def SPV_OC_OpConstantComposite      : I32EnumAttrCase<"OpConstantComposite", 44>;
+def SPV_OC_OpConstantNull           : I32EnumAttrCase<"OpConstantNull", 46>;
+def SPV_OC_OpSpecConstantTrue       : I32EnumAttrCase<"OpSpecConstantTrue", 48>;
+def SPV_OC_OpSpecConstantFalse      : I32EnumAttrCase<"OpSpecConstantFalse", 49>;
+def SPV_OC_OpSpecConstant           : I32EnumAttrCase<"OpSpecConstant", 50>;
+def SPV_OC_OpSpecConstantComposite  : I32EnumAttrCase<"OpSpecConstantComposite", 51>;
+def SPV_OC_OpFunction               : I32EnumAttrCase<"OpFunction", 54>;
+def SPV_OC_OpFunctionParameter      : I32EnumAttrCase<"OpFunctionParameter", 55>;
+def SPV_OC_OpFunctionEnd            : I32EnumAttrCase<"OpFunctionEnd", 56>;
+def SPV_OC_OpFunctionCall           : I32EnumAttrCase<"OpFunctionCall", 57>;
+def SPV_OC_OpVariable               : I32EnumAttrCase<"OpVariable", 59>;
+def SPV_OC_OpLoad                   : I32EnumAttrCase<"OpLoad", 61>;
+def SPV_OC_OpStore                  : I32EnumAttrCase<"OpStore", 62>;
+def SPV_OC_OpAccessChain            : I32EnumAttrCase<"OpAccessChain", 65>;
+def SPV_OC_OpDecorate               : I32EnumAttrCase<"OpDecorate", 71>;
+def SPV_OC_OpMemberDecorate         : I32EnumAttrCase<"OpMemberDecorate", 72>;
+def SPV_OC_OpCompositeExtract       : I32EnumAttrCase<"OpCompositeExtract", 81>;
+def SPV_OC_OpConvertFToU            : I32EnumAttrCase<"OpConvertFToU", 109>;
+def SPV_OC_OpConvertFToS            : I32EnumAttrCase<"OpConvertFToS", 110>;
+def SPV_OC_OpConvertSToF            : I32EnumAttrCase<"OpConvertSToF", 111>;
+def SPV_OC_OpConvertUToF            : I32EnumAttrCase<"OpConvertUToF", 112>;
+def SPV_OC_OpUConvert               : I32EnumAttrCase<"OpUConvert", 113>;
+def SPV_OC_OpSConvert               : I32EnumAttrCase<"OpSConvert", 114>;
+def SPV_OC_OpFConvert               : I32EnumAttrCase<"OpFConvert", 115>;
+def SPV_OC_OpBitcast                : I32EnumAttrCase<"OpBitcast", 124>;
+def SPV_OC_OpFNegate                : I32EnumAttrCase<"OpFNegate", 127>;
+def SPV_OC_OpIAdd                   : I32EnumAttrCase<"OpIAdd", 128>;
+def SPV_OC_OpFAdd                   : I32EnumAttrCase<"OpFAdd", 129>;
+def SPV_OC_OpISub                   : I32EnumAttrCase<"OpISub", 130>;
+def SPV_OC_OpFSub                   : I32EnumAttrCase<"OpFSub", 131>;
+def SPV_OC_OpIMul                   : I32EnumAttrCase<"OpIMul", 132>;
+def SPV_OC_OpFMul                   : I32EnumAttrCase<"OpFMul", 133>;
+def SPV_OC_OpUDiv                   : I32EnumAttrCase<"OpUDiv", 134>;
+def SPV_OC_OpSDiv                   : I32EnumAttrCase<"OpSDiv", 135>;
+def SPV_OC_OpFDiv                   : I32EnumAttrCase<"OpFDiv", 136>;
+def SPV_OC_OpUMod                   : I32EnumAttrCase<"OpUMod", 137>;
+def SPV_OC_OpSRem                   : I32EnumAttrCase<"OpSRem", 138>;
+def SPV_OC_OpSMod                   : I32EnumAttrCase<"OpSMod", 139>;
+def SPV_OC_OpFRem                   : I32EnumAttrCase<"OpFRem", 140>;
+def SPV_OC_OpFMod                   : I32EnumAttrCase<"OpFMod", 141>;
+def SPV_OC_OpLogicalEqual           : I32EnumAttrCase<"OpLogicalEqual", 164>;
+def SPV_OC_OpLogicalNotEqual        : I32EnumAttrCase<"OpLogicalNotEqual", 165>;
+def SPV_OC_OpLogicalOr              : I32EnumAttrCase<"OpLogicalOr", 166>;
+def SPV_OC_OpLogicalAnd             : I32EnumAttrCase<"OpLogicalAnd", 167>;
+def SPV_OC_OpLogicalNot             : I32EnumAttrCase<"OpLogicalNot", 168>;
+def SPV_OC_OpSelect                 : I32EnumAttrCase<"OpSelect", 169>;
+def SPV_OC_OpIEqual                 : I32EnumAttrCase<"OpIEqual", 170>;
+def SPV_OC_OpINotEqual              : I32EnumAttrCase<"OpINotEqual", 171>;
+def SPV_OC_OpUGreaterThan           : I32EnumAttrCase<"OpUGreaterThan", 172>;
+def SPV_OC_OpSGreaterThan           : I32EnumAttrCase<"OpSGreaterThan", 173>;
+def SPV_OC_OpUGreaterThanEqual      : I32EnumAttrCase<"OpUGreaterThanEqual", 174>;
+def SPV_OC_OpSGreaterThanEqual      : I32EnumAttrCase<"OpSGreaterThanEqual", 175>;
+def SPV_OC_OpULessThan              : I32EnumAttrCase<"OpULessThan", 176>;
+def SPV_OC_OpSLessThan              : I32EnumAttrCase<"OpSLessThan", 177>;
+def SPV_OC_OpULessThanEqual         : I32EnumAttrCase<"OpULessThanEqual", 178>;
+def SPV_OC_OpSLessThanEqual         : I32EnumAttrCase<"OpSLessThanEqual", 179>;
+def SPV_OC_OpFOrdEqual              : I32EnumAttrCase<"OpFOrdEqual", 180>;
+def SPV_OC_OpFUnordEqual            : I32EnumAttrCase<"OpFUnordEqual", 181>;
+def SPV_OC_OpFOrdNotEqual           : I32EnumAttrCase<"OpFOrdNotEqual", 182>;
+def SPV_OC_OpFUnordNotEqual         : I32EnumAttrCase<"OpFUnordNotEqual", 183>;
+def SPV_OC_OpFOrdLessThan           : I32EnumAttrCase<"OpFOrdLessThan", 184>;
+def SPV_OC_OpFUnordLessThan         : I32EnumAttrCase<"OpFUnordLessThan", 185>;
+def SPV_OC_OpFOrdGreaterThan        : I32EnumAttrCase<"OpFOrdGreaterThan", 186>;
+def SPV_OC_OpFUnordGreaterThan      : I32EnumAttrCase<"OpFUnordGreaterThan", 187>;
+def SPV_OC_OpFOrdLessThanEqual      : I32EnumAttrCase<"OpFOrdLessThanEqual", 188>;
+def SPV_OC_OpFUnordLessThanEqual    : I32EnumAttrCase<"OpFUnordLessThanEqual", 189>;
+def SPV_OC_OpFOrdGreaterThanEqual   : I32EnumAttrCase<"OpFOrdGreaterThanEqual", 190>;
+def SPV_OC_OpFUnordGreaterThanEqual : I32EnumAttrCase<"OpFUnordGreaterThanEqual", 191>;
+def SPV_OC_OpShiftRightLogical      : I32EnumAttrCase<"OpShiftRightLogical", 194>;
+def SPV_OC_OpShiftRightArithmetic   : I32EnumAttrCase<"OpShiftRightArithmetic", 195>;
+def SPV_OC_OpShiftLeftLogical       : I32EnumAttrCase<"OpShiftLeftLogical", 196>;
+def SPV_OC_OpBitwiseOr              : I32EnumAttrCase<"OpBitwiseOr", 197>;
+def SPV_OC_OpBitwiseXor             : I32EnumAttrCase<"OpBitwiseXor", 198>;
+def SPV_OC_OpBitwiseAnd             : I32EnumAttrCase<"OpBitwiseAnd", 199>;
+def SPV_OC_OpNot                    : I32EnumAttrCase<"OpNot", 200>;
+def SPV_OC_OpBitFieldInsert         : I32EnumAttrCase<"OpBitFieldInsert", 201>;
+def SPV_OC_OpBitFieldSExtract       : I32EnumAttrCase<"OpBitFieldSExtract", 202>;
+def SPV_OC_OpBitFieldUExtract       : I32EnumAttrCase<"OpBitFieldUExtract", 203>;
+def SPV_OC_OpBitReverse             : I32EnumAttrCase<"OpBitReverse", 204>;
+def SPV_OC_OpBitCount               : I32EnumAttrCase<"OpBitCount", 205>;
+def SPV_OC_OpControlBarrier         : I32EnumAttrCase<"OpControlBarrier", 224>;
+def SPV_OC_OpMemoryBarrier          : I32EnumAttrCase<"OpMemoryBarrier", 225>;
+def SPV_OC_OpPhi                    : I32EnumAttrCase<"OpPhi", 245>;
+def SPV_OC_OpLoopMerge              : I32EnumAttrCase<"OpLoopMerge", 246>;
+def SPV_OC_OpSelectionMerge         : I32EnumAttrCase<"OpSelectionMerge", 247>;
+def SPV_OC_OpLabel                  : I32EnumAttrCase<"OpLabel", 248>;
+def SPV_OC_OpBranch                 : I32EnumAttrCase<"OpBranch", 249>;
+def SPV_OC_OpBranchConditional      : I32EnumAttrCase<"OpBranchConditional", 250>;
+def SPV_OC_OpReturn                 : I32EnumAttrCase<"OpReturn", 253>;
+def SPV_OC_OpReturnValue            : I32EnumAttrCase<"OpReturnValue", 254>;
+def SPV_OC_OpUnreachable            : I32EnumAttrCase<"OpUnreachable", 255>;
+def SPV_OC_OpModuleProcessed        : I32EnumAttrCase<"OpModuleProcessed", 330>;
+
+def SPV_OpcodeAttr :
+    I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
+      SPV_OC_OpNop, SPV_OC_OpUndef, SPV_OC_OpSourceContinued, SPV_OC_OpSource,
+      SPV_OC_OpSourceExtension, SPV_OC_OpName, SPV_OC_OpMemberName, SPV_OC_OpString,
+      SPV_OC_OpExtension, SPV_OC_OpExtInstImport, SPV_OC_OpExtInst,
+      SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode,
+      SPV_OC_OpCapability, SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt,
+      SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeArray,
+      SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct, SPV_OC_OpTypePointer,
+      SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse,
+      SPV_OC_OpConstant, SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull,
+      SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant,
+      SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
+      SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad,
+      SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
+      SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpConvertFToU,
+      SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
+      SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast,
+      SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub,
+      SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv,
+      SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod,
+      SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr,
+      SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual,
+      SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
+      SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
+      SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
+      SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
+      SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan,
+      SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
+      SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
+      SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
+      SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,
+      SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr, SPV_OC_OpBitwiseXor,
+      SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, SPV_OC_OpBitFieldInsert,
+      SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, SPV_OC_OpBitReverse,
+      SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier,
+      SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel,
+      SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
+      SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpModuleProcessed
+      ]> {
+    let returnType = "::mlir::spirv::Opcode";
+    let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";
+    let cppNamespace = "::mlir::spirv";
+}
+
+// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
+
 //===----------------------------------------------------------------------===//
 // SPIR-V op definitions
 //===----------------------------------------------------------------------===//
diff --git a/third_party/mlir/utils/spirv/gen_spirv_dialect.py b/third_party/mlir/utils/spirv/gen_spirv_dialect.py
index 9aed98dba70..5ef56675a1a 100755
--- a/third_party/mlir/utils/spirv/gen_spirv_dialect.py
+++ b/third_party/mlir/utils/spirv/gen_spirv_dialect.py
@@ -303,7 +303,10 @@ def update_td_enum_attrs(path, operand_kinds, filter_list):
   # Sort alphabetically according to enum name
   defs.sort(key=lambda enum : enum[0])
   # Only keep the definitions from now on
-  defs = [enum[1] for enum in defs]
+  # Put Capability's definition at the very beginning because capability cases
+  # will be referenced later
+  defs = [enum[1] for enum in defs if enum[0] == 'Capability'
+         ] + [enum[1] for enum in defs if enum[0] != 'Capability']
 
   # Substitute the old section
   content = content[0] + AUTOGEN_ENUM_SECTION_MARKER + '\n\n' + \

From 6845eb81db73c26aba7fab67e4049606f668a831 Mon Sep 17 00:00:00 2001
From: Daniel Situnayake 
Date: Mon, 2 Dec 2019 14:53:21 -0800
Subject: [PATCH 176/279] Fix broken link

PiperOrigin-RevId: 283423464
Change-Id: I718a6b2666666b708e38e8933a6f26ad953beac1
---
 tensorflow/lite/g3doc/microcontrollers/index.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/lite/g3doc/microcontrollers/index.md b/tensorflow/lite/g3doc/microcontrollers/index.md
index 2ead371d4b4..64e80686116 100644
--- a/tensorflow/lite/g3doc/microcontrollers/index.md
+++ b/tensorflow/lite/g3doc/microcontrollers/index.md
@@ -36,7 +36,7 @@ There are example applications available for the following development boards:
 *   [SparkFun Edge](https://www.sparkfun.com/products/15170)
 *   [STM32F746 Discovery kit](https://www.st.com/en/evaluation-tools/32f746gdiscovery.html)
 *   [Adafruit EdgeBadge](https://www.adafruit.com/product/4400)
-*   [Adafruit TensorFlow Lite for Microcontrollers Kit]
+*   [Adafruit TensorFlow Lite for Microcontrollers Kit](https://www.adafruit.com/product/4317)
 
 To learn more about the libraries and examples, see
 [Get started with microcontrollers](get_started.md).

From 7fb4d81d6166c7ea35c604749bc05cd62799c73a Mon Sep 17 00:00:00 2001
From: Jian Li 
Date: Mon, 2 Dec 2019 14:54:00 -0800
Subject: [PATCH 177/279] Extend test framework to handle intermediate tensors.

PiperOrigin-RevId: 283423593
Change-Id: Ie6d321823636d49cf05e6ed29a8d220ae0476b0a
---
 tensorflow/lite/kernels/test_util.cc | 21 ++++++++++++++++++++-
 tensorflow/lite/kernels/test_util.h  |  4 ++++
 2 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/tensorflow/lite/kernels/test_util.cc b/tensorflow/lite/kernels/test_util.cc
index 12cde4cc9d1..67cd514e1e8 100644
--- a/tensorflow/lite/kernels/test_util.cc
+++ b/tensorflow/lite/kernels/test_util.cc
@@ -88,6 +88,24 @@ int SingleOpModel::AddInput(const TensorData& t, bool is_variable) {
   return id;
 }
 
+int SingleOpModel::AddIntermediate(TensorType type,
+                                   const std::vector& scale,
+                                   const std::vector& zero_point) {
+  // Currently supports only int16 intermediate types.
+  // TODO(jianlijianli): make use of the type.
+  int id = tensors_.size();
+  flatbuffers::Offset q_params =
+      CreateQuantizationParameters(builder_, /*min=*/0, /*max=*/0,
+                                   builder_.CreateVector(scale),
+                                   builder_.CreateVector(zero_point));
+  tensors_.push_back(CreateTensor(builder_, builder_.CreateVector({}),
+                                  type,
+                                  /*buffer=*/0,
+                                  /*name=*/0, q_params, false));
+  intermediates_.push_back(id);
+  return id;
+}
+
 int SingleOpModel::AddNullInput() {
   int id = kTfLiteOptionalTensor;
   inputs_.push_back(id);
@@ -108,7 +126,8 @@ void SingleOpModel::SetBuiltinOp(BuiltinOperator type,
       builder_, /*opcode_index=*/0, builder_.CreateVector(inputs_),
       builder_.CreateVector(outputs_), builtin_options_type,
       builtin_options,
-      /*custom_options=*/0, CustomOptionsFormat_FLEXBUFFERS));
+      /*custom_options=*/0, CustomOptionsFormat_FLEXBUFFERS, 0,
+      builder_.CreateVector(intermediates_)));
 }
 
 void SingleOpModel::SetCustomOp(
diff --git a/tensorflow/lite/kernels/test_util.h b/tensorflow/lite/kernels/test_util.h
index 380d9b10e89..d9f3bc9d584 100644
--- a/tensorflow/lite/kernels/test_util.h
+++ b/tensorflow/lite/kernels/test_util.h
@@ -165,6 +165,9 @@ class SingleOpModel {
   }
   int AddInput(const TensorData& t, bool is_variable = false);
 
+  int AddIntermediate(TensorType type, const std::vector& scale,
+                      const std::vector& zero_point);
+
   // Templated version of AddConstInput().
   template 
   int AddConstInput(const TensorData& t, std::initializer_list data) {
@@ -587,6 +590,7 @@ class SingleOpModel {
 
   std::map tensor_data_;
   std::vector inputs_;
+  std::vector intermediates_;
   std::vector outputs_;
   std::vector> tensors_;
   std::vector> opcodes_;

From 8543b64d1a696f3c4348bced963feae1f0a907fc Mon Sep 17 00:00:00 2001
From: George Karpenkov 
Date: Mon, 2 Dec 2019 14:59:20 -0800
Subject: [PATCH 178/279] [XLA/CPU] Generalize all-reduce to multiple dense
 dimensions

Additionally, assert that on GPU we only have dense layouts.

PiperOrigin-RevId: 283424672
Change-Id: Ieac5368837ef057e1a4aa5fbcb4bbe61822e141a
---
 .../compiler/xla/service/cpu/cpu_runtime.cc   |   5 +-
 .../xla/service/gpu/nccl_all_reduce_thunk.cc  |   4 +-
 .../compiler/xla/tests/collective_ops_test.cc | 121 ++++++++++--------
 3 files changed, 72 insertions(+), 58 deletions(-)

diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index 2cb15f6ec4d..9b3e85427a3 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -25,6 +25,7 @@ limitations under the License.
 #include "absl/strings/str_join.h"
 #include "absl/synchronization/mutex.h"
 #include "tensorflow/compiler/xla/executable_run_options.h"
+#include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/refcounting_hash_map.h"
 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
@@ -422,10 +423,10 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
 
   xla::Shape shape =
       DecodeSelfDescribingShapeConstant(shape_ptr, shape_length).ValueOrDie();
+  CHECK(xla::LayoutUtil::IsDenseArray(shape))
+      << "All-reduce on CPU is implemented only for dense arrays";
 
   xla::AllReduceParticipantData participant(rendezvous_key);
-
-  CHECK_LE(shape.dimensions_size(), 1);
   participant.element_count = xla::ShapeUtil::ElementsIn(shape);
   participant.device_ordinal = device_ordinal;
   participant.primitive_type = shape.element_type();
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
index ac80552d032..d74e7f22916 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
@@ -29,6 +29,7 @@ limitations under the License.
 #include "absl/types/optional.h"
 #include "absl/types/span.h"
 #include "third_party/nccl/nccl.h"
+#include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/refcounting_hash_map.h"
 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
@@ -454,7 +455,8 @@ struct NcclAllReduceThunk::AuxData {
   return MatchReductionComputation(crs->to_apply()).has_value() &&
          DatatypeToNccl(AllReducePrimitiveType(crs)).has_value() &&
          crs->IsCrossReplicaAllReduce() &&
-         crs->operand_count() == 1;  // One array to reduce.
+         crs->operand_count() == 1 &&  // One array to reduce.
+         LayoutUtil::IsDenseArray(crs->operand(0)->shape());
 }
 
 /*static*/ absl::flat_hash_set
diff --git a/tensorflow/compiler/xla/tests/collective_ops_test.cc b/tensorflow/compiler/xla/tests/collective_ops_test.cc
index 42f687a7996..8de508e876e 100644
--- a/tensorflow/compiler/xla/tests/collective_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/collective_ops_test.cc
@@ -41,7 +41,7 @@ using ::testing::UnorderedElementsAre;
 class CollectiveOpsTest : public HloTestBase {
  protected:
   std::unique_ptr MakeCrsModule(
-      int64 num_elems, std::vector> replica_groups,
+      const Shape& shape, std::vector> replica_groups,
       const HloModuleConfig& config, std::string op = "add",
       std::string datatype = "f32") {
     std::string hlo_template = R"(
@@ -54,11 +54,11 @@ class CollectiveOpsTest : public HloTestBase {
       }
 
       ENTRY test_computation {
-        p = DATATYPE[NUM_ELEMS] parameter(0)
-        p2 = DATATYPE[NUM_ELEMS] bitcast(p)
-        crs = DATATYPE[NUM_ELEMS] all-reduce(p2), replica_groups=REPLICA_GROUPS, to_apply=apply_op
-        copy = DATATYPE[NUM_ELEMS] copy(crs)
-        ROOT out = DATATYPE[NUM_ELEMS] bitcast(copy)
+        p = SHAPE parameter(0)
+        p2 = SHAPE bitcast(p)
+        crs = SHAPE all-reduce(p2), replica_groups=REPLICA_GROUPS, to_apply=apply_op
+        copy = SHAPE copy(crs)
+        ROOT out = SHAPE bitcast(copy)
       }
     )";
     std::vector replica_group_strs;
@@ -66,71 +66,70 @@ class CollectiveOpsTest : public HloTestBase {
       replica_group_strs.push_back(
           absl::StrFormat("{%s}", absl::StrJoin(g, ",")));
     }
-    if (num_elems == 1) {
+    std::string shape_str = shape.ToString(/*print_layout=*/false);
+    if (shape_str == "f32[1]") {
       // Exercise the scalar codepath.
       hlo_template = absl::StrReplaceAll(
           hlo_template,
-          {{"DATATYPE[NUM_ELEMS] bitcast(p)", "DATATYPE[] bitcast(p)"},
-           {"DATATYPE[NUM_ELEMS] all-reduce", "DATATYPE[] all-reduce"},
-           {"DATATYPE[NUM_ELEMS] copy", "DATATYPE[] copy"}});
+          {{"DATATYPE[SHAPE] bitcast(p)", "DATATYPE[] bitcast(p)"},
+           {"DATATYPE[SHAPE] all-reduce", "DATATYPE[] all-reduce"},
+           {"DATATYPE[SHAPE] copy", "DATATYPE[] copy"}});
     }
-    return ParseAndReturnVerifiedModule(
-               absl::StrReplaceAll(
-                   hlo_template,
-                   {{"NUM_ELEMS", absl::StrCat(num_elems)},
-                    {"REPLICA_GROUPS",
-                     absl::StrFormat("{%s}",
-                                     absl::StrJoin(replica_group_strs, ", "))},
-                    {"OP", op},
-                    {"DATATYPE", datatype}}),
-               config)
-        .ValueOrDie();
+    std::string parameterized_hlo = absl::StrReplaceAll(
+        hlo_template,
+        {{"SHAPE", shape_str},
+         {"REPLICA_GROUPS",
+          absl::StrFormat("{%s}", absl::StrJoin(replica_group_strs, ", "))},
+         {"OP", op},
+         {"DATATYPE", datatype}});
+    return ParseAndReturnVerifiedModule(parameterized_hlo, config).ValueOrDie();
   }
 
   template 
-  void TestTwoReplicasOneOperand(std::string op,
-                                 std::vector input_value,
-                                 std::vector expected_value) {
+  void TestTwoReplicasOneOperand(std::string op, Literal input_value,
+                                 Literal expected_value) {
     const int kNumReplicas = 2;
     std::string dtype = primitive_util::LowercasePrimitiveTypeName(
         primitive_util::NativeToPrimitiveType());
     auto config = GetModuleConfigForTest();
     config.set_replica_count(kNumReplicas);
-    auto module = MakeCrsModule(/*num_elems=*/input_value.size(),
-                                /*replica_groups=*/{}, config,
-                                /*op=*/op, /*datatype=*/dtype);
-    auto literal = LiteralUtil::CreateR1(input_value);
-    auto expected = LiteralUtil::CreateR1(expected_value);
+    auto module = MakeCrsModule(
+        /*shape_str=*/input_value.shape(),
+        /*replica_groups=*/{}, config,
+        /*op=*/op, /*datatype=*/dtype);
     TF_ASSERT_OK_AND_ASSIGN(std::vector results,
-                            ExecuteReplicated(std::move(module), {&literal},
+                            ExecuteReplicated(std::move(module), {&input_value},
                                               /*num_replicas=*/kNumReplicas,
                                               /*use_threads=*/true));
     for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) {
-      EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected, results[replica_idx],
-                                               ErrorSpec{1e-5, 1e-5}));
+      EXPECT_TRUE(LiteralTestUtil::NearOrEqual(
+          expected_value, results[replica_idx], ErrorSpec{1e-5, 1e-5}));
     }
   }
 
   template 
   void TestAllOps() {
     auto cast = [&](int value) { return static_cast(value); };
-    std::vector input_value = {cast(1), cast(2), cast(3)};
+    auto to_literal = [&](absl::Span values) {
+      return LiteralUtil::CreateR1(values);
+    };
+    Literal input_value = to_literal({cast(1), cast(2), cast(3)});
     TestTwoReplicasOneOperand(
         "add",
-        /*input_value=*/input_value,
-        /*expected_value=*/{cast(2), cast(4), cast(6)});
+        /*input_value=*/input_value.Clone(),
+        /*expected_value=*/to_literal({cast(2), cast(4), cast(6)}));
     TestTwoReplicasOneOperand(
         "multiply",
-        /*input_value=*/input_value,
-        /*expected_value=*/{cast(1), cast(4), cast(9)});
+        /*input_value=*/input_value.Clone(),
+        /*expected_value=*/to_literal({cast(1), cast(4), cast(9)}));
     TestTwoReplicasOneOperand(
         "maximum",
-        /*input_value=*/input_value,
-        /*expected_value=*/{cast(1), cast(2), cast(3)});
+        /*input_value=*/input_value.Clone(),
+        /*expected_value=*/to_literal({cast(1), cast(2), cast(3)}));
     TestTwoReplicasOneOperand(
         "minimum",
-        /*input_value=*/input_value,
-        /*expected_value=*/{cast(1), cast(2), cast(3)});
+        /*input_value=*/input_value.Clone(),
+        /*expected_value=*/to_literal({cast(1), cast(2), cast(3)}));
   }
 };
 
@@ -169,10 +168,18 @@ static Eigen::half ToHalf(T value) {
   return static_cast(value);
 }
 
+XLA_TEST_F(CollectiveOpsTest, AllReduce_sum_float32_2D) {
+  TestTwoReplicasOneOperand(
+      "add",
+      /*input_value=*/LiteralUtil::CreateR2({{1, 2}, {3, 4}}),
+      /*expected_value=*/LiteralUtil::CreateR2({{2, 4}, {6, 8}}));
+}
+
 XLA_TEST_F(CollectiveOpsTest, AllReduceSingleOutput_float32) {
-  TestTwoReplicasOneOperand("add",
-                                   /*input_value=*/{1},
-                                   /*expected_value=*/{2});
+  TestTwoReplicasOneOperand(
+      "add",
+      /*input_value=*/LiteralUtil::CreateR1({1}),
+      /*expected_value=*/LiteralUtil::CreateR1({2}));
 }
 
 XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int8) {
@@ -227,12 +234,13 @@ XLA_TEST_F(CollectiveOpsTest, AllReduce_AllCombinations) {
     config.set_replica_count(devices.size());
     config.set_static_device_assignment(device_assn);
 
-    auto module = MakeCrsModule(kNumElems, /*replica_groups=*/{}, config);
-
     std::vector input_vec(kNumElems);
     absl::c_iota(input_vec, 0);
     auto input_literal = LiteralUtil::CreateR1(input_vec);
 
+    auto module = MakeCrsModule(input_literal.shape(),
+                                /*replica_groups=*/{}, config);
+
     TF_ASSERT_OK_AND_ASSIGN(
         std::vector results,
         ExecuteReplicated(std::move(module), {&input_literal},
@@ -270,7 +278,8 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllReduce_NcclChannelCaching)) {
     auto config = GetModuleConfigForTest();
     config.set_replica_count(devices.size());
     config.set_static_device_assignment(e.device_assn);
-    auto module = MakeCrsModule(kNumElems, /*replica_groups=*/{}, config);
+    auto module = MakeCrsModule(input_literal.shape(),
+                                /*replica_groups=*/{}, config);
     e.executable =
         test_runner_
             .CreateExecutable(std::move(module), /*run_hlo_passes=*/true)
@@ -325,20 +334,21 @@ XLA_TEST_F(CollectiveOpsTest, AllReduce_ManyConcurrentAllReduces) {
   const int64 kNumThreads = 200;
   const int64 kRunsPerThread = 10;
 
+  std::vector input_vec(kNumElems);
+  absl::c_iota(input_vec, 0);
+  auto input_literal = LiteralUtil::CreateR1(input_vec);
+
   auto config = GetModuleConfigForTest();
   config.set_replica_count(2);
   auto executable =
       test_runner_
-          .CreateExecutable(
-              MakeCrsModule(kNumElems, /*replica_groups=*/{}, config),
-              /*run_hlo_passes=*/true)
+          .CreateExecutable(MakeCrsModule(input_literal.shape(),
+                                          /*replica_groups=*/{}, config),
+                            /*run_hlo_passes=*/true)
           .ValueOrDie();
   std::vector devices = {0, 1};
   auto device_assn = MakeDeviceAssn(devices);
 
-  std::vector input_vec(kNumElems);
-  absl::c_iota(input_vec, 0);
-  auto input_literal = LiteralUtil::CreateR1(input_vec);
   HloRunner::ReplicatedExecuteOptions opts;
   opts.num_replicas = devices.size();
   opts.use_threads = true;
@@ -368,11 +378,12 @@ XLA_TEST_F(CollectiveOpsTest, AllReduce_ThreeReplicaGroups) {
 
   auto config = GetModuleConfigForTest();
   config.set_replica_count(4);
-  auto module = MakeCrsModule(/*num_elems=*/kNumElems,
-                              /*replica_groups=*/{{0}, {1, 2}, {3}}, config);
   std::vector input_vec(kNumElems);
   absl::c_iota(input_vec, 0);
   auto input_literal = LiteralUtil::CreateR1(input_vec);
+  auto module = MakeCrsModule(
+      /*shape_str=*/input_literal.shape(),
+      /*replica_groups=*/{{0}, {1, 2}, {3}}, config);
 
   TF_ASSERT_OK_AND_ASSIGN(
       std::vector results,

From 6eafe7f72b2f980dc375d80f550243b2f4f6562d Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Mon, 2 Dec 2019 15:03:39 -0800
Subject: [PATCH 179/279] disable the compact-operands printing to get the full
 strings of HLO ops

PiperOrigin-RevId: 283425766
Change-Id: I50e6be34bd4f2b291db6521b70435e5cbd9b3178
---
 tensorflow/compiler/xla/service/hlo_instruction.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 5e2e53ea6db..5855911650d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -113,7 +113,7 @@ class HloPrintOptions {
         .set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies)
         .set_print_metadata(false)
         .set_print_backend_config(false)
-        .set_compact_operands(true)
+        .set_compact_operands(false)
         .set_print_operand_names(false)
         .set_print_operand_shape(true)
         .set_print_program_shape(false)

From d5ee347de231b55f8ef7c11402db1673ff111d53 Mon Sep 17 00:00:00 2001
From: Alexandre Passos 
Date: Mon, 2 Dec 2019 15:08:31 -0800
Subject: [PATCH 180/279] Fix issue #31952 with passing IndexedSlices to
 backward functions.

PiperOrigin-RevId: 283426833
Change-Id: Ib2733f177d3446ceadc6f5da0da8929592286c02
---
 tensorflow/python/eager/backprop_test.py | 13 +++++++++++++
 tensorflow/python/eager/function.py      |  6 ++++++
 2 files changed, 19 insertions(+)

diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 23cfbd44972..62a808f44d7 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -307,6 +307,19 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
       y = array_ops.identity(x)
     self.assertEqual(t.gradient(y, x).numpy(), 1.0)
 
+  def testFunctionIndexedSlicesGradient(self):
+
+    @def_function.function
+    def f(x):
+      return x + 1
+
+    with backprop.GradientTape() as t:
+      x = constant_op.constant([1.0])
+      t.watch(x)
+      y = f(x)
+      y = array_ops.gather(y, [0])
+    self.assertAllEqual(t.gradient(y, x), [1.0])
+
   def testTapeGradientMultiTargetOneIsSource(self):
     x = constant_op.constant(2.0)
     with backprop.GradientTape() as t:
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 63263c03a97..2d8b442e1af 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -1224,6 +1224,12 @@ class _TapeGradientFunctions(object):
       processed_args = []
       input_index = 0
       for output_index, arg in enumerate(args):
+        # Convert IndexedSlices to dense tensors. The IndexedSlices optimization
+        # is only really effective when doing tf.gather(variable) as the
+        # adjoint functions for most operations are unlikely to preserve the
+        # sparsity in IndexedSlices.
+        if isinstance(arg, ops.IndexedSlices):
+          arg = ops.convert_to_tensor(arg)
         if output_index in skip_positions:
           continue
         if arg is None:

From 55897e52bc3646eb57e19e20e56dd323637aa4f6 Mon Sep 17 00:00:00 2001
From: Jian Li 
Date: Mon, 2 Dec 2019 15:10:08 -0800
Subject: [PATCH 181/279] Add unit test for LSTM.

PiperOrigin-RevId: 283427129
Change-Id: Id46861e6cdfd867e055e8d984bb9d7bcfd4cc336
---
 tensorflow/lite/kernels/lstm_test.cc | 474 +++++++++++++++++++++++++++
 1 file changed, 474 insertions(+)

diff --git a/tensorflow/lite/kernels/lstm_test.cc b/tensorflow/lite/kernels/lstm_test.cc
index d6a5f9a23cc..ac2f28cf278 100644
--- a/tensorflow/lite/kernels/lstm_test.cc
+++ b/tensorflow/lite/kernels/lstm_test.cc
@@ -2079,6 +2079,480 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
   VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
 }
 
+class LSTMIntegerOpModel : public SingleOpModel {
+ public:
+  LSTMIntegerOpModel(int n_batch, int n_input, int n_cell, int n_output,
+                     bool use_cifg, bool use_peephole,
+                     bool use_projection_weights, bool use_projection_bias,
+                     bool use_layer_norm, float cell_clip, float proj_clip,
+                     const std::vector>& input_shapes,
+                     const std::vector>& ranges,
+                     const std::vector>& intermediates)
+      : n_batch_(n_batch),
+        n_input_(n_input),
+        n_cell_(n_cell),
+        n_output_(n_output) {
+    EXPECT_EQ(input_shapes.size() + 1, ranges.size());
+    EXPECT_EQ(intermediates.size(), 5);
+    input_ = AddInput(
+        {TensorType_INT8, input_shapes[0], ranges[0].first, ranges[0].second});
+
+    if (use_cifg) {
+      input_to_input_weights_ = AddNullInput();
+    } else {
+      input_to_input_weights_ = AddInput({TensorType_INT8, input_shapes[1],
+                                          ranges[1].first, ranges[1].second});
+    }
+    input_to_forget_weights_ = AddInput(
+        {TensorType_INT8, input_shapes[2], ranges[2].first, ranges[2].second});
+    input_to_cell_weights_ = AddInput(
+        {TensorType_INT8, input_shapes[3], ranges[3].first, ranges[3].second});
+    input_to_output_weights_ = AddInput(
+        {TensorType_INT8, input_shapes[4], ranges[4].first, ranges[4].second});
+
+    if (use_cifg) {
+      recurrent_to_input_weights_ = AddNullInput();
+    } else {
+      recurrent_to_input_weights_ =
+          AddInput({TensorType_INT8, input_shapes[5], ranges[5].first,
+                    ranges[5].second});
+    }
+    recurrent_to_forget_weights_ = AddInput(
+        {TensorType_INT8, input_shapes[6], ranges[6].first, ranges[6].second});
+    recurrent_to_cell_weights_ = AddInput(
+        {TensorType_INT8, input_shapes[7], ranges[7].first, ranges[7].second});
+    recurrent_to_output_weights_ = AddInput(
+        {TensorType_INT8, input_shapes[8], ranges[8].first, ranges[8].second});
+
+    if (use_peephole) {
+      if (use_cifg) {
+        cell_to_input_weights_ = AddNullInput();
+      } else {
+        cell_to_input_weights_ = AddInput({TensorType_INT16, input_shapes[9],
+                                           ranges[9].first, ranges[9].second});
+      }
+      cell_to_forget_weights_ = AddInput({TensorType_INT16, input_shapes[10],
+                                          ranges[10].first, ranges[10].second});
+      cell_to_output_weights_ = AddInput({TensorType_INT8, input_shapes[11],
+                                          ranges[11].first, ranges[11].second});
+    } else {
+      cell_to_input_weights_ = AddNullInput();
+      cell_to_forget_weights_ = AddNullInput();
+      cell_to_output_weights_ = AddNullInput();
+    }
+
+    if (use_cifg) {
+      input_gate_bias_ = AddNullInput();
+    } else {
+      input_gate_bias_ = AddInput({TensorType_INT32, input_shapes[12],
+                                   ranges[12].first, ranges[12].second});
+    }
+    forget_gate_bias_ = AddInput({TensorType_INT32, input_shapes[13],
+                                  ranges[13].first, ranges[13].second});
+    cell_bias_ = AddInput({TensorType_INT32, input_shapes[14], ranges[14].first,
+                           ranges[14].second});
+    output_gate_bias_ = AddInput({TensorType_INT32, input_shapes[15],
+                                  ranges[15].first, ranges[15].second});
+
+    if (use_projection_weights) {
+      projection_weights_ = AddInput({TensorType_INT8, input_shapes[16],
+                                      ranges[16].first, ranges[16].second});
+      if (use_projection_bias) {
+        projection_bias_ = AddInput({TensorType_INT32, input_shapes[17],
+                                     ranges[17].first, ranges[17].second});
+      } else {
+        projection_bias_ = AddNullInput();
+      }
+    } else {
+      projection_weights_ = AddNullInput();
+      projection_bias_ = AddNullInput();
+    }
+
+    // Adding the 2 input state tensors.
+    input_activation_state_ = AddInput({TensorType_INT16, input_shapes[18],
+                                        ranges[18].first, ranges[18].second},
+                                       true);
+    input_cell_state_ = AddInput({TensorType_INT16, input_shapes[19],
+                                  ranges[19].first, ranges[19].second},
+                                 true);
+
+    // Layer norm weights.
+    if (use_layer_norm) {
+      if (use_cifg) {
+        input_layer_norm_coefficients_ = AddNullInput();
+      } else {
+        input_layer_norm_coefficients_ =
+            AddInput({TensorType_INT16, input_shapes[20], ranges[20].first,
+                      ranges[20].second});
+      }
+      forget_layer_norm_coefficients_ =
+          AddInput({TensorType_INT16, input_shapes[21], ranges[21].first,
+                    ranges[21].second});
+      cell_layer_norm_coefficients_ =
+          AddInput({TensorType_INT16, input_shapes[22], ranges[22].first,
+                    ranges[22].second});
+      output_layer_norm_coefficients_ =
+          AddInput({TensorType_INT16, input_shapes[23], ranges[23].first,
+                    ranges[23].second});
+    }
+
+    for (int i = 0; i < intermediates.size(); ++i) {
+      intermediates_[i] =
+          AddIntermediate(TensorType_INT16, {intermediates[i].first},
+                          {intermediates[i].second});
+    }
+
+    output_ = AddOutput({TensorType_INT8,
+                         {n_batch, n_output},
+                         ranges[24].first,
+                         ranges[24].second});
+
+    SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
+                 CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
+                                   cell_clip, proj_clip)
+                     .Union());
+
+    // Do not apply delegate yet since tensor values are not known (and more
+    // specifically scales in quantized tensors are not known).
+    BuildInterpreter(input_shapes, /*allow_fp32_relax_to_fp16=*/false,
+                     /*apply_delegate=*/false);
+  }
+
+  void SetInputToInputWeights(const std::vector& f) {
+    QuantizeAndPopulate(input_to_input_weights_, f);
+  }
+
+  void SetInputToForgetWeights(const std::vector& f) {
+    QuantizeAndPopulate(input_to_forget_weights_, f);
+  }
+
+  void SetInputToCellWeights(const std::vector& f) {
+    QuantizeAndPopulate(input_to_cell_weights_, f);
+  }
+
+  void SetInputToOutputWeights(const std::vector& f) {
+    QuantizeAndPopulate(input_to_output_weights_, f);
+  }
+
+  void SetRecurrentToInputWeights(const std::vector& f) {
+    QuantizeAndPopulate(recurrent_to_input_weights_, f);
+  }
+
+  void SetRecurrentToForgetWeights(const std::vector& f) {
+    QuantizeAndPopulate(recurrent_to_forget_weights_, f);
+  }
+
+  void SetRecurrentToCellWeights(const std::vector& f) {
+    QuantizeAndPopulate(recurrent_to_cell_weights_, f);
+  }
+
+  void SetRecurrentToOutputWeights(const std::vector& f) {
+    QuantizeAndPopulate(recurrent_to_output_weights_, f);
+  }
+
+  void SetCellToInputWeights(const std::vector& f) {
+    QuantizeAndPopulate(cell_to_input_weights_, f);
+  }
+
+  void SetCellToForgetWeights(const std::vector& f) {
+    QuantizeAndPopulate(cell_to_forget_weights_, f);
+  }
+
+  void SetCellToOutputWeights(const std::vector& f) {
+    QuantizeAndPopulate(cell_to_output_weights_, f);
+  }
+
+  void SetInputLayerNormCoefficients(const std::vector& f) {
+    QuantizeAndPopulate(input_layer_norm_coefficients_, f);
+  }
+
+  void SetForgetLayerNormCoefficients(const std::vector& f) {
+    QuantizeAndPopulate(forget_layer_norm_coefficients_, f);
+  }
+
+  void SetCellLayerNormCoefficients(const std::vector& f) {
+    QuantizeAndPopulate(cell_layer_norm_coefficients_, f);
+  }
+
+  void SetOutputLayerNormCoefficients(const std::vector& f) {
+    QuantizeAndPopulate(output_layer_norm_coefficients_, f);
+  }
+
+  void SetInputGateBias(const std::vector& f) {
+    QuantizeAndPopulate(input_gate_bias_, f);
+  }
+
+  void SetForgetGateBias(const std::vector& f) {
+    QuantizeAndPopulate(forget_gate_bias_, f);
+  }
+
+  void SetCellBias(const std::vector& f) {
+    QuantizeAndPopulate(cell_bias_, f);
+  }
+
+  void SetOutputGateBias(const std::vector& f) {
+    QuantizeAndPopulate(output_gate_bias_, f);
+  }
+
+  void SetProjectionWeights(const std::vector& f) {
+    QuantizeAndPopulate(projection_weights_, f);
+  }
+
+  void SetProjectionBias(const std::vector& f) {
+    QuantizeAndPopulate(projection_bias_, f);
+  }
+
+  void SetInput(const std::vector& f) {
+    QuantizeAndPopulate(input_, f);
+  }
+
+  std::vector GetOutput() { return ExtractVector(output_); }
+
+  int num_inputs() { return n_input_; }
+  int num_outputs() { return n_output_; }
+  int num_cells() { return n_cell_; }
+  int num_batches() { return n_batch_; }
+
+ protected:
+  int input_;
+  int input_to_input_weights_;
+  int input_to_forget_weights_;
+  int input_to_cell_weights_;
+  int input_to_output_weights_;
+
+  int recurrent_to_input_weights_;
+  int recurrent_to_forget_weights_;
+  int recurrent_to_cell_weights_;
+  int recurrent_to_output_weights_;
+
+  int cell_to_input_weights_;
+  int cell_to_forget_weights_;
+  int cell_to_output_weights_;
+
+  int input_layer_norm_coefficients_;
+  int forget_layer_norm_coefficients_;
+  int cell_layer_norm_coefficients_;
+  int output_layer_norm_coefficients_;
+
+  int input_gate_bias_;
+  int forget_gate_bias_;
+  int cell_bias_;
+  int output_gate_bias_;
+
+  int projection_weights_;
+  int projection_bias_;
+  int input_activation_state_;
+  int input_cell_state_;
+
+  int intermediates_[5];
+
+  int output_;
+  int output_state_;
+  int cell_state_;
+
+  int n_batch_;
+  int n_input_;
+  int n_cell_;
+  int n_output_;
+};
+
+TEST(LSTMIntegerOpModel, NoCifgYesLayerNormNoYesProjectionNoPeephole) {
+  // Hyper parameters.
+  const int n_batch = 2;
+  const int n_input = 5;
+  const int n_cell = 4;
+  const int n_output = 3;
+  const float cell_clip = 0.0;
+  const float proj_clip = 0.0;
+
+  // Model related weights.
+  const std::vector input_to_input_weights = {
+      0.5,  0.6, 0.7,  -0.8, -0.9, 0.1,  0.2,  0.3,  -0.4, 0.5,
+      -0.8, 0.7, -0.6, 0.5,  -0.4, -0.5, -0.4, -0.3, -0.2, -0.1};
+
+  const std::vector input_to_forget_weights = {
+      -0.6, -0.1, 0.3,  0.2,  0.9,  -0.5, -0.2, -0.4, 0.3,  -0.8,
+      -0.4, 0.3,  -0.5, -0.4, -0.6, 0.3,  -0.4, -0.6, -0.5, -0.5};
+
+  const std::vector input_to_cell_weights = {
+      -0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, -0.3, -0.2, -0.6,
+      0.6,  -0.1, -0.4, -0.3, -0.7, 0.7, -0.9, -0.5, 0.8,  0.6};
+
+  const std::vector input_to_output_weights = {
+      -0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, -0.3, -0.8, -0.2,
+      0.6,  -0.2, 0.4,  -0.7, -0.3, -0.5, 0.1, 0.5,  -0.6, -0.4};
+
+  const std::vector input_gate_bias = {0.03, 0.15, 0.22, 0.38};
+
+  const std::vector forget_gate_bias = {0.1, -0.3, -0.2, 0.1};
+
+  const std::vector cell_gate_bias = {-0.05, 0.72, 0.25, 0.08};
+
+  const std::vector output_gate_bias = {0.05, -0.01, 0.2, 0.1};
+
+  const std::vector recurrent_to_input_weights = {
+      -0.2, -0.3, 0.4, 0.1, -0.5, 0.9, -0.2, -0.3, -0.7, 0.05, -0.2, -0.6};
+
+  const std::vector recurrent_to_cell_weights = {
+      -0.3, 0.2, 0.1, -0.3, 0.8, -0.08, -0.2, 0.3, 0.8, -0.6, -0.1, 0.2};
+
+  const std::vector recurrent_to_forget_weights = {
+      -0.5, -0.3, -0.5, -0.2, 0.6, 0.4, 0.9, 0.3, -0.1, 0.2, 0.5, 0.2};
+
+  const std::vector recurrent_to_output_weights = {
+      0.3, -0.1, 0.1, -0.2, -0.5, -0.7, -0.2, -0.6, -0.1, -0.4, -0.7, -0.2};
+
+  const std::vector input_layer_norm_coefficients = {0.1, 0.2, 0.3, 0.5};
+  const std::vector forget_layer_norm_coefficients = {0.2, 0.2, 0.4,
+                                                             0.3};
+  const std::vector cell_layer_norm_coefficients = {0.7, 0.2, 0.3, 0.8};
+  const std::vector output_layer_norm_coefficients = {0.6, 0.2, 0.2,
+                                                             0.5};
+
+  const std::vector projection_weights = {
+      -0.1, 0.2, 0.01, -0.2, 0.1, 0.5, 0.3, 0.08, 0.07, 0.2, -0.4, 0.2};
+
+  // Input shapes.
+  const std::vector> inputs = {
+      {n_batch, n_input},  // input tensor
+
+      {n_cell, n_input},  // input_to_input_weight tensor
+      {n_cell, n_input},  // input_to_forget_weight tensor
+      {n_cell, n_input},  // input_to_cell_weight tensor
+      {n_cell, n_input},  // input_to_output_weight tensor
+
+      {n_cell, n_output},  // recurrent_to_input_weight tensor
+      {n_cell, n_output},  // recurrent_to_forget_weight tensor
+      {n_cell, n_output},  // recurrent_to_cell_weight tensor
+      {n_cell, n_output},  // recurrent_to_output_weight tensor
+
+      {0},  // cell_to_input_weight tensor
+      {0},  // cell_to_forget_weight tensor
+      {0},  // cell_to_output_weight tensor
+
+      {n_cell},  // input_gate_bias tensor
+      {n_cell},  // forget_gate_bias tensor
+      {n_cell},  // cell_bias tensor
+      {n_cell},  // output_gate_bias tensor
+
+      {n_output, n_cell},  // projection_weight tensor
+      {0},                 // projection_bias tensor
+
+      {n_batch, n_output},  // activation_state tensor
+      {n_batch, n_cell},    // cell_state tensor
+
+      {n_cell},  // input_layer_norm_coefficient tensor
+      {n_cell},  // forget_layer_norm_coefficient tensor
+      {n_cell},  // cell_layer_norm_coefficient tensor
+      {n_cell},  // output_layer_norm_coefficient tensor
+  };
+
+  // Input ranges.
+  const std::vector> ranges = {
+      {-1.0, 127.0 / 128},  // input tensor
+      {-1.0, 0.9},          // input_to_input_weight tensor
+      {-1.0, 1.0},          // input_to_forget_weight tensor
+      {-1.0, 1.0},          // input_to_cell_weight tensor
+      {-1.0, 0.8},          // input_to_output_weight tensor
+
+      {-0.8, 1.0},  // recurrent_to_input_weight tensor
+      {-0.8, 0.9},  // recurrent_to_forget_weight tensor
+      {-0.8, 1.0},  // recurrent_to_cell_weight tensor
+      {-1.0, 1.0},  // recurrent_to_output_weight tensor
+
+      {-1, 1},  // cell_to_input_weight tensor
+      {-1, 1},  // cell_to_forget_weight tensor
+      {-1, 1},  // cell_to_output_weight tensor
+
+      {-100, 100},  // input_gate_bias tensor
+      {-100, 80},   // forget_gate_bias tensor
+      {-100, 100},  // cell_bias tensor
+      {-100, 100},  // output_gate_bias tensor
+
+      {-0.5, 0.5},  // projection_weight tensor
+      {-1, 1},      // projection_bias tensor
+
+      {-1.0, 32767.0 / 32768},  // activation_state tensor
+      {-1, 1},                  // cell_state tensor
+
+      {0, 0.5},  // input_layer_norm_coefficient tensor
+      {0, 0.5},  // forget_layer_norm_coefficient tensor
+      {0, 1.0},  // cell_layer_norm_coefficient tensor
+      {0, 1.0},  // output_layer_norm_coefficient tensor
+      // Output scale is the same as input activation scale and only activation
+      // scale is used in the op, so this is only provided for clarity.
+      {-1.0, 32767.0 / 32768},  // output tensor.
+  };
+
+  // The scale and zero point of intermediate tensors.
+  std::vector> intermediates = {
+      {0.007059, 0}, {0.007812, 0}, {0.007059, 0}, {0.007812, 0}, {0.007, 0}};
+
+  // Create model.
+  LSTMIntegerOpModel lstm(n_batch, n_input, n_cell, n_output,
+                          /*use_cifg=*/false, /*use_peephole=*/false,
+                          /*use_projection_weights=*/true,
+                          /*use_projection_bias=*/false,
+                          /*use_layer_norm=*/true, cell_clip, proj_clip, inputs,
+                          ranges, intermediates);
+
+  // Set weights.
+  lstm.SetInputToInputWeights(input_to_input_weights);
+  lstm.SetInputToCellWeights(input_to_cell_weights);
+  lstm.SetInputToForgetWeights(input_to_forget_weights);
+  lstm.SetInputToOutputWeights(input_to_output_weights);
+
+  lstm.SetInputGateBias(input_gate_bias);
+  lstm.SetCellBias(cell_gate_bias);
+  lstm.SetForgetGateBias(forget_gate_bias);
+  lstm.SetOutputGateBias(output_gate_bias);
+
+  lstm.SetRecurrentToInputWeights(recurrent_to_input_weights);
+  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights);
+  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights);
+  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights);
+
+  lstm.SetProjectionWeights(projection_weights);
+
+  lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients);
+  lstm.SetForgetLayerNormCoefficients(forget_layer_norm_coefficients);
+  lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients);
+  lstm.SetOutputLayerNormCoefficients(output_layer_norm_coefficients);
+
+  // Model inputs. sequence -batch - input
+  const std::vector> lstm_input = {
+      {
+          0.7, 0.8, 0.1, 0.2, 0.3,  //
+          0.8, 0.1, 0.2, 0.4, 0.5,  //
+      },
+      {
+          0.2, 0.7, 0.7, 0.1, 0.7,  //
+          0.3, 0.2, 0.9, 0.8, 0.1,  //
+      },
+      {
+          0.7, 0.8, 0.1, 0.2, 0.3,  //
+          0.3, 0.2, 0.9, 0.8, 0.1,  //
+      },
+  };
+
+  // Expected outputs.
+  const std::vector> expected_output = {
+      {107, 127, 127, -41, 127, 127},
+      {53, 127, 127, 22, 127, 127},
+      {90, 127, 127, 34, 127, 127},
+  };
+
+  // Invoke and verify the result.
+  const int input_sequence_size = lstm_input.size();
+  EXPECT_GT(input_sequence_size, 0);
+  for (int i = 0; i < input_sequence_size; ++i) {
+    lstm.SetInput(lstm_input[i]);
+    lstm.Invoke();
+    const auto x = lstm.GetOutput();
+    EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(expected_output[i]));
+  }
+}
+
 #ifdef GTEST_HAS_DEATH_TEST
 TEST(LSTMOpModel, InvalidTypeTest) {
   const int n_batch = 1;

From f9919d25507251b5487d66757c515d7ef29b43a4 Mon Sep 17 00:00:00 2001
From: Renjie Liu 
Date: Mon, 2 Dec 2019 15:16:51 -0800
Subject: [PATCH 182/279] Fix build, vcombine should really be unsigned. :(

PiperOrigin-RevId: 283428415
Change-Id: I4253d92e59594142bbe8233b9caea3c821ec045f
---
 tensorflow/lite/kernels/internal/optimized/optimized_ops.h | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
index f8fc6113e61..26005e069a7 100644
--- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
@@ -1002,14 +1002,14 @@ inline void MeanImpl(const tflite::MeanParams& op_params,
           vmovn_u32(vreinterpretq_u32_s32(temp_sum.val[3]));
 
       uint16x8_t combined_low =
-          vcombine_s16(narrowed_low_low, narrowed_high_low);
+          vcombine_u16(narrowed_low_low, narrowed_high_low);
       uint16x8_t combined_high =
-          vcombine_s16(narrowed_low_high, narrowed_high_high);
+          vcombine_u16(narrowed_low_high, narrowed_high_high);
 
       uint8x8_t narrowed_low = vmovn_u16(combined_low);
       uint8x8_t narrowed_high = vmovn_u16(combined_high);
 
-      uint8x16_t combined_output = vcombine_s8(narrowed_low, narrowed_high);
+      uint8x16_t combined_output = vcombine_u8(narrowed_low, narrowed_high);
 
       uint8_t* output_data_ptr =
           output_data + Offset(output_shape, out_b, 0, 0, out_d);

From 4970fccde24e628c6bfccf4d31b387b41d6a193a Mon Sep 17 00:00:00 2001
From: Smit Hinsu 
Date: Mon, 2 Dec 2019 15:22:28 -0800
Subject: [PATCH 183/279] Wrap legalize_tf and hlo_ops within namespaces
 instead of using `using` directives

This is to have consistency with other files and to follow Google code style.

Also,
* Namespace 'xla' members are referred with qualified name lookup within 'xla_hlo'
  namespace.
* Wrapped hlo_utils and convert_op_folder within namespace mlir. Now, all files except
  related to import/export are within namespace mlir.

PiperOrigin-RevId: 283429443
Change-Id: I69ef723598e10432f244eeb117cff9a1cccf153d
---
 .../compiler/mlir/xla/convert_op_folder.cc    |   2 +
 .../compiler/mlir/xla/convert_op_folder.h     |   2 +
 tensorflow/compiler/mlir/xla/ir/hlo_ops.cc    |  17 +-
 tensorflow/compiler/mlir/xla/ir/hlo_utils.td  |   4 +-
 .../mlir/xla/transforms/legalize_tf.cc        | 252 +++++++++---------
 .../xla/transforms/legalize_tf_patterns.td    |   2 +-
 6 files changed, 136 insertions(+), 143 deletions(-)

diff --git a/tensorflow/compiler/mlir/xla/convert_op_folder.cc b/tensorflow/compiler/mlir/xla/convert_op_folder.cc
index d26bec292cc..8245b4a0585 100644
--- a/tensorflow/compiler/mlir/xla/convert_op_folder.cc
+++ b/tensorflow/compiler/mlir/xla/convert_op_folder.cc
@@ -21,6 +21,7 @@ limitations under the License.
 #include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
 #include "mlir/IR/TypeUtilities.h"  // TF:local_config_mlir
 
+namespace mlir {
 namespace xla {
 
 mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
@@ -82,3 +83,4 @@ mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
 }
 
 }  // namespace xla
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/xla/convert_op_folder.h b/tensorflow/compiler/mlir/xla/convert_op_folder.h
index 1c3f75489f8..63ac0e61df5 100644
--- a/tensorflow/compiler/mlir/xla/convert_op_folder.h
+++ b/tensorflow/compiler/mlir/xla/convert_op_folder.h
@@ -19,6 +19,7 @@ limitations under the License.
 #include "mlir/IR/Attributes.h"  // TF:local_config_mlir
 #include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
 
+namespace mlir {
 namespace xla {
 
 // Converts the given elements attr to the specified elements type.
@@ -27,5 +28,6 @@ namespace xla {
 mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
                                        mlir::Type new_type);
 }  // namespace xla
+}  // namespace mlir
 
 #endif  // TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
index fb74d509ffb..639c85c48b5 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
@@ -50,10 +50,8 @@ limitations under the License.
 
 namespace mlir {
 #include "tensorflow/compiler/mlir/xla/ir/hlo_structs.cc.inc"
-}  // namespace mlir
 
-using namespace mlir;
-using namespace mlir::xla_hlo;
+namespace xla_hlo {
 
 Operation* XlaHloDialect::materializeConstant(OpBuilder& builder,
                                               Attribute value, Type type,
@@ -212,9 +210,9 @@ void AbsOp::build(Builder* builder, OperationState& result, Value* operand) {
     new_type = operand->getType();
   } else if (shaped_type.hasRank()) {
     new_type =
-        mlir::RankedTensorType::get(shaped_type.getShape(), operand->getType());
+        RankedTensorType::get(shaped_type.getShape(), operand->getType());
   } else {
-    new_type = mlir::UnrankedTensorType::get(operand->getType());
+    new_type = UnrankedTensorType::get(operand->getType());
   }
 
   return AbsOp::build(builder, result, new_type, operand);
@@ -241,8 +239,8 @@ OpFoldResult ConvertOp::fold(ArrayRef operands) {
 
   // If the operand is constant, we can do the conversion now.
   if (auto elementsAttr = operands.front().dyn_cast_or_null()) {
-    return ::xla::ConvertElementsAttr(elementsAttr,
-                                      getElementTypeOrSelf(getResult()));
+    return xla::ConvertElementsAttr(elementsAttr,
+                                    getElementTypeOrSelf(getResult()));
   }
 
   return {};
@@ -436,7 +434,7 @@ static LogicalResult Verify(ClampOp op) {
 void ComplexOp::build(Builder* builder, OperationState& state, Value* lhs,
                       Value* rhs) {
   auto type = lhs->getType();
-  auto element_ty = mlir::ComplexType::get(getElementTypeOrSelf(type));
+  auto element_ty = ComplexType::get(getElementTypeOrSelf(type));
   Type result_ty;
   if (auto ranked_type = type.dyn_cast()) {
     result_ty = RankedTensorType::get(ranked_type.getShape(), element_ty);
@@ -990,3 +988,6 @@ XlaHloDialect::XlaHloDialect(MLIRContext* context)
   // Support unknown operations because not all XLA operations are registered.
   // allowUnknownOperations();
 }
+
+}  // namespace xla_hlo
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_utils.td b/tensorflow/compiler/mlir/xla/ir/hlo_utils.td
index bd1a448b80f..1a56d230d0d 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_utils.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_utils.td
@@ -27,11 +27,11 @@ def NullArrayAttr : NativeCodeCall<"ArrayAttr()">;
 def CastIntElementsAttr : NativeCodeCall<"$0.cast()">;
 
 class ConstantSplat : NativeCodeCall<
-    "getSplat(&$_builder, $0, " # value # ")">;
+    "xla::getSplat(&$_builder, $0, " # value # ")">;
 
 def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
 
 def BinBroadcastDimensions : NativeCodeCall<
-    "getBroadcastDimensionsAttr(&$_builder, $0, $1)">;
+    "xla::getBroadcastDimensionsAttr(&$_builder, $0, $1)">;
 
 #endif // HLO_UTILS
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index b1e94927ecb..a156685f005 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -47,9 +47,10 @@ limitations under the License.
 #include "tensorflow/core/util/padding.h"
 #include "tensorflow/core/util/tensor_format.h"
 
-using namespace mlir;
-
+namespace mlir {
+namespace xla_hlo {
 namespace {
+
 class LegalizeTF : public FunctionPass {
  public:
   struct Options : public PassOptions {
@@ -72,12 +73,6 @@ class LegalizeTF : public FunctionPass {
  private:
   bool allow_partial_conversion_;
 };
-}  // end anonymous namespace
-
-std::unique_ptr>
-mlir::xla_hlo::createLegalizeTFPass(bool allow_partial_conversion) {
-  return std::make_unique(allow_partial_conversion);
-}
 
 /// Returns if the given TF data format string is the default format.
 static bool isDefaultDataFormat(StringRef format) { return format == "NHWC"; }
@@ -131,10 +126,9 @@ static llvm::Optional GetIntegerHLOAxisFromTFAxis(Value *value,
 
 /// Returns a `ConvertOp` that casts the elements to a i64 type while retaining
 /// the shape of the input value.
-static xla_hlo::ConvertOp CastElementsToI64(Location loc, Value *value,
-                                            PatternRewriter *rewriter) {
-  return rewriter->create(loc, value,
-                                              rewriter->getIntegerType(64));
+static ConvertOp CastElementsToI64(Location loc, Value *value,
+                                   PatternRewriter *rewriter) {
+  return rewriter->create(loc, value, rewriter->getIntegerType(64));
 }
 
 // Returns size of dimension at the specified index, if ranked tensor.
@@ -155,8 +149,8 @@ tensorflow::TensorShape ToTensorShape(llvm::ArrayRef sizes) {
 }
 
 // Returns minimum value for the given int or float element type.
-static xla_hlo::ConstOp GetMinValueForType(Type ty, Location loc,
-                                           PatternRewriter *rewriter) {
+static ConstOp GetMinValueForType(Type ty, Location loc,
+                                  PatternRewriter *rewriter) {
   RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
 
   DenseElementsAttr attr;
@@ -169,14 +163,13 @@ static xla_hlo::ConstOp GetMinValueForType(Type ty, Location loc,
     APInt min_val = APInt::getSignedMinValue(int_ty.getWidth());
     attr = DenseElementsAttr::get(scalar_ty, min_val);
   }
-  return rewriter->create(loc, attr);
+  return rewriter->create(loc, attr);
 }
 
 // Returns int or float scalar DenseElementsAttr attribute with the given
 // element type and the value.
-static xla_hlo::ConstOp GetScalarOfType(Type ty, Location loc,
-                                        int64_t raw_value,
-                                        PatternRewriter *rewriter) {
+static ConstOp GetScalarOfType(Type ty, Location loc, int64_t raw_value,
+                               PatternRewriter *rewriter) {
   RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
 
   DenseElementsAttr attr;
@@ -188,7 +181,7 @@ static xla_hlo::ConstOp GetScalarOfType(Type ty, Location loc,
     APInt value(int_ty.getWidth(), static_cast(raw_value), true);
     attr = DenseElementsAttr::get(scalar_ty, value);
   }
-  return rewriter->create(loc, attr);
+  return rewriter->create(loc, attr);
 }
 
 // Builds body for reduce op by using the using the template binary op as the
@@ -207,7 +200,7 @@ static void BuildReduceBody(Type element_type, Region *body,
   auto reducer =
       builder->create(loc, block->getArgument(0), block->getArgument(1),
                           /*broadcast_dimensions=*/nullptr);
-  builder->create(loc, reducer.getResult());
+  builder->create(loc, reducer.getResult());
 }
 
 //===----------------------------------------------------------------------===//
@@ -360,17 +353,17 @@ static void BuildArgMinMaxReductionBody(Type input_element_type,
   Location loc = body->getLoc();
   StringAttr compare_direction =
       StringAttr::get(direction, builder->getContext());
-  Value *compare = builder->create(
+  Value *compare = builder->create(
       loc, block->getArgument(0), block->getArgument(2),
       /*broadcast_dimensions=*/nullptr, compare_direction);
 
-  Value *selected_input = builder->create(
+  Value *selected_input = builder->create(
       loc, input_type, compare, block->getArgument(0), block->getArgument(2));
-  Value *selected_index = builder->create(
+  Value *selected_index = builder->create(
       loc, index_type, compare, block->getArgument(1), block->getArgument(3));
 
   Value *return_values[] = {selected_input, selected_index};
-  builder->create(loc, return_values);
+  builder->create(loc, return_values);
 }
 
 //===----------------------------------------------------------------------===//
@@ -446,10 +439,6 @@ static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes(
 // Op converters.
 //===----------------------------------------------------------------------===//
 
-namespace mlir {
-namespace xla {
-namespace {
-
 NamedAttribute GetConvDimensionNumbersAttr(
     ArrayRef spatial_dim_indices, tensorflow::TensorFormat format,
     Builder *builder) {
@@ -474,7 +463,7 @@ NamedAttribute GetConvDimensionNumbersAttr(
 
   return builder->getNamedAttr(
       "dimension_numbers",
-      mlir::xla_hlo::ConvDimensionNumbers::get(
+      ConvDimensionNumbers::get(
           batch_dim, feature_dim, spatial_dims, kernel_input_feature_dim,
           kernel_output_feature_dim, kernel_spatial_dimensions, batch_dim,
           feature_dim, spatial_dims, builder->getContext()));
@@ -602,8 +591,8 @@ class ConvertConv : public OpRewritePattern {
     NamedAttribute attrs[] = {rhs_dilations_attr,     window_strides_attr,
                               dimension_numbers_attr, feature_group_count_attr,
                               batch_group_count_attr, paddings_attr};
-    rewriter.replaceOpWithNewOp(op, op.getType(), operands,
-                                                 llvm::makeArrayRef(attrs));
+    rewriter.replaceOpWithNewOp(op, op.getType(), operands,
+                                        llvm::makeArrayRef(attrs));
     return Pattern::matchSuccess();
   }
 };
@@ -635,18 +624,16 @@ class ConvertBF16FloorDivOp : public OpRewritePattern {
 
     auto out_type = op.z()->getType().cast();
 
-    l = rewriter.create(op.getLoc(), l,
-                                            rewriter.getF32Type());
-    r = rewriter.create(op.getLoc(), r,
-                                            rewriter.getF32Type());
+    l = rewriter.create(op.getLoc(), l, rewriter.getF32Type());
+    r = rewriter.create(op.getLoc(), r, rewriter.getF32Type());
 
     auto intermediate = rewriter.create(
         op.getLoc(),
         ChangeTensorElementType(&rewriter, out_type, rewriter.getF32Type()), l,
         r);
 
-    auto floor_op = rewriter.create(op.getLoc(), out_type,
-                                                        intermediate);
+    auto floor_op =
+        rewriter.create(op.getLoc(), out_type, intermediate);
     rewriter.replaceOp(op, floor_op.getResult());
     return Pattern::matchSuccess();
   }
@@ -674,15 +661,15 @@ class ConvertMaxPoolOp : public OpRewritePattern {
         op.input()->getType().cast().getElementType();
     if (!element_type.isIntOrFloat()) return matchFailure();
     Location loc = op.getLoc();
-    xla_hlo::ConstOp init = GetMinValueForType(element_type, loc, &rewriter);
+    ConstOp init = GetMinValueForType(element_type, loc, &rewriter);
 
-    auto reduce = rewriter.create(
+    auto reduce = rewriter.create(
         loc, op.getType(), op.input(), init.getResult(),
         GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
         /*base_dilations=*/DenseIntElementsAttr(),
         /*window_dilations=*/DenseIntElementsAttr(),
         /*paddings=*/DenseIntElementsAttr());
-    BuildReduceBody(element_type, &reduce.body(), &rewriter);
+    BuildReduceBody(element_type, &reduce.body(), &rewriter);
 
     rewriter.replaceOp(op, reduce.getResult());
     return matchSuccess();
@@ -717,28 +704,28 @@ class ConvertSigmoidOp : public OpRewritePattern {
                                      PatternRewriter &rewriter) const override {
     auto operand = op.getOperand();
 
-    auto scalar_one = rewriter.create(
+    auto scalar_one = rewriter.create(
         op.getLoc(),
         rewriter.getFloatAttr(getElementTypeOrSelf(operand->getType()), 0.5));
 
     auto shaped_type = operand->getType().cast();
-    auto constant_ones = rewriter.create(
+    auto constant_ones = rewriter.create(
         op.getLoc(), shaped_type, scalar_one,
         DenseIntElementsAttr::get(
             RankedTensorType::get({shaped_type.getRank()},
                                   rewriter.getIntegerType(64)),
             shaped_type.getShape()));
 
-    auto scaled_input = rewriter.create(
+    auto scaled_input = rewriter.create(
         op.getLoc(), operand, constant_ones, DenseIntElementsAttr());
-    auto tanh_op = rewriter.create(
-        op.getLoc(), operand->getType(), scaled_input);
-    auto mul_op = rewriter.create(
-        op.getLoc(), tanh_op, constant_ones,
-        /*DenseIntElementsAttr=*/DenseIntElementsAttr());
-    auto add_op = rewriter.create(
-        op.getLoc(), mul_op, constant_ones,
-        /*DenseIntElementsAttr=*/DenseIntElementsAttr());
+    auto tanh_op =
+        rewriter.create(op.getLoc(), operand->getType(), scaled_input);
+    auto mul_op =
+        rewriter.create(op.getLoc(), tanh_op, constant_ones,
+                               /*DenseIntElementsAttr=*/DenseIntElementsAttr());
+    auto add_op =
+        rewriter.create(op.getLoc(), mul_op, constant_ones,
+                               /*DenseIntElementsAttr=*/DenseIntElementsAttr());
 
     rewriter.replaceOp(op, add_op.getResult());
     return matchSuccess();
@@ -803,11 +790,11 @@ class ConvertSoftmaxOp : public OpRewritePattern {
     auto max_logits =
         rewriter.create(loc, logits, reduce_dim,
                                    /*keep_dims=*/rewriter.getBoolAttr(false));
-    auto shifted_logits = rewriter.create(
-        loc, type, logits, max_logits, batch_dims);
+    auto shifted_logits =
+        rewriter.create(loc, type, logits, max_logits, batch_dims);
 
     // Exponentiate the inputs.
-    Value *exp = rewriter.create(loc, type, shifted_logits);
+    Value *exp = rewriter.create(loc, type, shifted_logits);
 
     // Compute summation of the exponentials.
     auto exp_sum =
@@ -816,11 +803,10 @@ class ConvertSoftmaxOp : public OpRewritePattern {
     Value *sum = exp_sum.getResult();
 
     if (use_log) {
-      Value *log = rewriter.create(loc, sum);
-      rewriter.replaceOpWithNewOp(op, shifted_logits, log,
-                                                  batch_dims);
+      Value *log = rewriter.create(loc, sum);
+      rewriter.replaceOpWithNewOp(op, shifted_logits, log, batch_dims);
     } else {
-      rewriter.replaceOpWithNewOp(op, exp, sum, batch_dims);
+      rewriter.replaceOpWithNewOp(op, exp, sum, batch_dims);
     }
     return Pattern::matchSuccess();
   }
@@ -864,10 +850,10 @@ class ConvertSizeOp : public OpRewritePattern {
         GetScalarOfType(result_type.cast().getElementType(),
                         op.getLoc(), 1, &rewriter);
     for (int64_t i = 0; i < rank; ++i) {
-      auto dim = rewriter.create(
+      auto dim = rewriter.create(
           op.getLoc(), result_type, input,
           rewriter.getIntegerAttr(rewriter.getIntegerType(32), i));
-      size = rewriter.create(
+      size = rewriter.create(
           op.getLoc(), size->getResult(0), dim.getResult(),
           /*DenseIntElementsAttr=*/DenseIntElementsAttr());
     }
@@ -953,11 +939,11 @@ class ConvertSplitOp : public OpRewritePattern {
     for (int i = 0; i < num_splits; ++i) {
       begin_indices[dim_index] = i * slice_size;
       end_indices[dim_index] = (i + 1) * slice_size;
-      slices.push_back(rewriter.create(
-          op.getLoc(), slice_type, op.value(),
-          GetI64ElementsAttr(begin_indices, &rewriter),
-          GetI64ElementsAttr(end_indices, &rewriter),
-          GetI64ElementsAttr(strides, &rewriter)));
+      slices.push_back(
+          rewriter.create(op.getLoc(), slice_type, op.value(),
+                                   GetI64ElementsAttr(begin_indices, &rewriter),
+                                   GetI64ElementsAttr(end_indices, &rewriter),
+                                   GetI64ElementsAttr(strides, &rewriter)));
     }
 
     rewriter.replaceOp(op, slices);
@@ -1059,10 +1045,10 @@ class ConvertStridedSliceOp : public OpRewritePattern {
     }
 
     Location loc = op.getLoc();
-    auto reversed = rewriter.create(
+    auto reversed = rewriter.create(
         loc, input_ty, op.input(),
         GetI64ElementsAttr(dims_to_reverse, &rewriter));
-    auto sliced = rewriter.create(
+    auto sliced = rewriter.create(
         loc, reversed.getResult(),
         GetI64ElementsAttr(hlo_begin_indices, &rewriter),
         GetI64ElementsAttr(hlo_end_indices, &rewriter),
@@ -1070,7 +1056,7 @@ class ConvertStridedSliceOp : public OpRewritePattern {
 
     // Reshape slice result so that the shape is updated depending on
     // 'new_axis_mask' or 'shrink_axis_mask' attributes.
-    rewriter.replaceOpWithNewOp(op, op.getType(), sliced);
+    rewriter.replaceOpWithNewOp(op, op.getType(), sliced);
     return matchSuccess();
   }
 };
@@ -1104,14 +1090,14 @@ class ConvertRangeOp : public OpRewritePattern {
       return matchFailure();
     }
 
-    auto iota = rewriter.create(op.getLoc(), result_type,
-                                                 rewriter.getI64IntegerAttr(0));
-    auto scaled = rewriter.create(
+    auto iota = rewriter.create(op.getLoc(), result_type,
+                                        rewriter.getI64IntegerAttr(0));
+    auto scaled = rewriter.create(
         op.getLoc(), result_type, iota, op.delta(),
-        getBroadcastDimensionsAttr(&rewriter, iota, op.delta()));
-    rewriter.replaceOpWithNewOp(
+        xla::getBroadcastDimensionsAttr(&rewriter, iota, op.delta()));
+    rewriter.replaceOpWithNewOp(
         op, result_type, scaled, op.start(),
-        getBroadcastDimensionsAttr(&rewriter, scaled, op.start()));
+        xla::getBroadcastDimensionsAttr(&rewriter, scaled, op.start()));
     return matchSuccess();
   }
 };
@@ -1159,13 +1145,13 @@ class GenericConvertReductionOp : public OpRewritePattern {
     // repeated arithmetic operations.
     Type reduce_element_type =
         is_accumulation ? GetAccumulationType(element_type) : element_type;
-    auto casted_input = rewriter.create(
-        loc, op.input(), reduce_element_type);
+    auto casted_input =
+        rewriter.create(loc, op.input(), reduce_element_type);
 
     // Each reduction op can have a different initial value.
     Value *init = Derived::GetInitialValue(reduce_element_type, loc, rewriter);
 
-    auto reduction = rewriter.create(
+    auto reduction = rewriter.create(
         loc, casted_input.getResult(), init,
         GetI64ElementsAttr(xla_dimensions, &rewriter));
     BuildReduceBody(reduce_element_type, &reduction.body(),
@@ -1186,16 +1172,16 @@ class GenericConvertReductionOp : public OpRewritePattern {
       auto divisor =
           GetScalarOfType(reduce_element_type, loc, divisor_count, &rewriter);
       auto broadcast_dims = GetI64ElementsAttr({}, &rewriter);
-      result = rewriter.create(loc, result, divisor.getResult(),
-                                               broadcast_dims);
+      result = rewriter.create(loc, result, divisor.getResult(),
+                                      broadcast_dims);
     }
 
-    result = rewriter.create(loc, result, element_type);
+    result = rewriter.create(loc, result, element_type);
 
     // Need to reshape back after the reduction if we're keeping the reduced
     // dimensions.
     if (op.keep_dims()) {
-      result = rewriter.create(loc, op.getType(), result);
+      result = rewriter.create(loc, op.getType(), result);
     }
     rewriter.replaceOp(op, {result}, {op.reduction_indices()});
 
@@ -1211,8 +1197,7 @@ class GenericConvertReductionOp : public OpRewritePattern {
 //   %divisor = constant dense<...> : tensor
 //   %mean = "xla_hlo.div"(%sum, %divisor)
 class ConvertMeanOp
-    : public GenericConvertReductionOp {
+    : public GenericConvertReductionOp {
  public:
   using GenericConvertReductionOp::GenericConvertReductionOp;
 
@@ -1227,8 +1212,8 @@ class ConvertMeanOp
 //   %init = constant dense<...> : tensor
 //   %sum = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.add"]
 //               {dimensions = ...}
-class ConvertSumOp : public GenericConvertReductionOp {
+class ConvertSumOp
+    : public GenericConvertReductionOp {
  public:
   using GenericConvertReductionOp::GenericConvertReductionOp;
 
@@ -1244,7 +1229,7 @@ class ConvertSumOp : public GenericConvertReductionOp {
  public:
   using GenericConvertReductionOp::GenericConvertReductionOp;
@@ -1304,7 +1289,7 @@ class ConvertArgMinMaxOp : public OpRewritePattern {
     IntegerAttr iota_dimension =
         IntegerAttr::get(rewriter.getIntegerType(64), axis);
     Value *index_values =
-        rewriter.create(loc, index_type, iota_dimension);
+        rewriter.create(loc, index_type, iota_dimension);
 
     std::vector dimensions = input_type.getShape();
     dimensions.erase(dimensions.begin() + axis);
@@ -1315,7 +1300,7 @@ class ConvertArgMinMaxOp : public OpRewritePattern {
     DenseIntElementsAttr reduction_dimensions =
         GetI64ElementsAttr({axis}, &rewriter);
 
-    auto reduction = rewriter.create(
+    auto reduction = rewriter.create(
         loc, llvm::ArrayRef(operands),
         llvm::ArrayRef(init_values), reduction_dimensions);
     StringRef direction = Derived::GetDirection();
@@ -1403,12 +1388,12 @@ class ConvertTileOp : public OpRewritePattern {
         RankedTensorType::get(broadcasted_shape, element_type);
     Type output_type = op.getType();
 
-    Value *result = rewriter.create(
+    Value *result = rewriter.create(
         loc, broadcasted_type, op.input(),
         GetI64ElementsAttr(broadcast_dimensions, &rewriter));
 
     if (output_type != broadcasted_type) {
-      result = rewriter.create(loc, output_type, result);
+      result = rewriter.create(loc, output_type, result);
     }
 
     rewriter.replaceOp(op, {result}, {op.multiples()});
@@ -1431,13 +1416,13 @@ class ConvertMaxPoolGradOp : public OpRewritePattern {
     Type element_type =
         op.orig_input()->getType().cast().getElementType();
 
-    auto result = rewriter.create(
+    auto result = rewriter.create(
         loc, op.getType(), op.orig_input(), op.grad(),
         GetScalarOfType(element_type, loc, 0, &rewriter),
         GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
         nullptr);
 
-    BuildReduceBody(element_type, &result.scatter(), &rewriter);
+    BuildReduceBody(element_type, &result.scatter(), &rewriter);
     {
       OpBuilder::InsertionGuard guard(rewriter);
       Block *block = rewriter.createBlock(&result.select());
@@ -1446,11 +1431,11 @@ class ConvertMaxPoolGradOp : public OpRewritePattern {
       Type type = RankedTensorType::get(/*shape=*/{}, element_type);
       block->addArguments({type, type});
 
-      auto reducer = rewriter.create(
+      auto reducer = rewriter.create(
           loc, block->getArgument(0), block->getArgument(1),
           /*broadcast_dimensions=*/nullptr,
           StringAttr::get("GE", rewriter.getContext()));
-      rewriter.create(loc, reducer.getResult());
+      rewriter.create(loc, reducer.getResult());
     }
 
     rewriter.replaceOp(op, {result}, {op.orig_output()});
@@ -1533,7 +1518,7 @@ class ConvertConv2DBackpropInputOp
       return matchFailure();
     }
 
-    // Compute xla_hlo::ConvDimensionNumbers, dilation, and padding.
+    // Compute ConvDimensionNumbers, dilation, and padding.
     SmallVector kernel_spatial_dims(num_spatial_dims);
     SmallVector conv_paddings(num_spatial_dims * 2);
     SmallVector lhs_dilation(num_spatial_dims);
@@ -1550,7 +1535,7 @@ class ConvertConv2DBackpropInputOp
       lhs_dilation[i] = dims.spatial_dims[i].stride;
       rhs_dilation[i] = dilations[dim];
     }
-    RankedTensorType paddings_ty = mlir::RankedTensorType::get(
+    RankedTensorType paddings_ty = RankedTensorType::get(
         {num_spatial_dims, 2}, rewriter.getIntegerType(64));
     auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, conv_paddings);
     auto spatial_dims_attr = GetI64ElementsAttr(spatial_dims, &rewriter);
@@ -1567,17 +1552,17 @@ class ConvertConv2DBackpropInputOp
     }
 
     // Mirror the filter in the spatial dimensions.
-    filter = rewriter.create(
+    filter = rewriter.create(
         loc, filter, GetI64ElementsAttr(kernel_spatial_dims, &rewriter));
 
     // activation gradients
     //   = gradients (with padding and dilation)  mirrored_weights
-    Value *result = rewriter.create(
+    Value *result = rewriter.create(
         loc, op.getType(), op.out_backprop(), filter,
         /*window_strides=*/GetI64ElementsAttr(ones, &rewriter),
         /*padding=*/paddings_attr, GetI64ElementsAttr(lhs_dilation, &rewriter),
         GetI64ElementsAttr(rhs_dilation, &rewriter),
-        xla_hlo::ConvDimensionNumbers::get(
+        ConvDimensionNumbers::get(
             /*input_batch_dimension=*/batch_dim_attr,
             /*input_feature_dimension=*/feature_dim_attr,
             /*input_spatial_dimensions=*/spatial_dims_attr,
@@ -1689,7 +1674,7 @@ class ConvertConv2DBackpropFilterOp
       return matchFailure();
     }
 
-    // Compute xla_hlo::ConvDimensionNumbers, dilation, and padding.
+    // Compute ConvDimensionNumbers, dilation, and padding.
     SmallVector conv_padding(num_spatial_dims * 2);
     SmallVector rhs_dilation(num_spatial_dims);
     SmallVector window_strides(num_spatial_dims);
@@ -1761,7 +1746,7 @@ class ConvertConv2DBackpropFilterOp
       conv_padding[i * 2 + 1] = pad_total - pad_before;
     }
 
-    RankedTensorType paddings_ty = mlir::RankedTensorType::get(
+    RankedTensorType paddings_ty = RankedTensorType::get(
         {num_spatial_dims, 2}, rewriter.getIntegerType(64));
     auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, conv_padding);
     auto out_spatial_dims_attr =
@@ -1773,12 +1758,12 @@ class ConvertConv2DBackpropFilterOp
     auto feature_dim_attr = rewriter.getI64IntegerAttr(feature_dim);
 
     Location loc = op.getLoc();
-    Value *result = rewriter.create(
+    Value *result = rewriter.create(
         loc, op.getType(), op.input(), op.out_backprop(),
         /*window_strides=*/GetI64ElementsAttr(window_strides, &rewriter),
         /*padding=*/paddings_attr, GetI64ElementsAttr(lhs_dilation, &rewriter),
         GetI64ElementsAttr(rhs_dilation, &rewriter),
-        xla_hlo::ConvDimensionNumbers::get(
+        ConvDimensionNumbers::get(
             // Swap batch_dim and feature_dim in the activations.
             /*input_batch_dimension=*/feature_dim_attr,
             /*input_feature_dimension=*/batch_dim_attr,
@@ -1836,21 +1821,21 @@ class ConvertOneHotOp : public OpRewritePattern {
 
     Location loc = op.getLoc();
     auto index_type = RankedTensorType::get(output_dims, element_type);
-    Value *compare = rewriter.create(
+    Value *compare = rewriter.create(
         loc, op.indices(),
-        rewriter.create(
+        rewriter.create(
             loc, index_type,
             IntegerAttr::get(rewriter.getIntegerType(64), axis)),
         GetI64ElementsAttr(broadcast_dims, &rewriter),
         StringAttr::get("EQ", rewriter.getContext()));
-    Value *on_value = rewriter.create(
+    Value *on_value = rewriter.create(
         loc, op.getType(), op.on_value(),
         GetI64ElementsAttr(output_dims, &rewriter));
-    Value *off_value = rewriter.create(
+    Value *off_value = rewriter.create(
         loc, op.getType(), op.off_value(),
         GetI64ElementsAttr(output_dims, &rewriter));
-    Value *result = rewriter.create(
-        loc, op.getType(), compare, on_value, off_value);
+    Value *result = rewriter.create(loc, op.getType(), compare,
+                                              on_value, off_value);
 
     rewriter.replaceOp(
         op, {result},
@@ -1861,41 +1846,35 @@ class ConvertOneHotOp : public OpRewritePattern {
 };
 
 #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"
-}  // end anonymous namespace
-}  // end namespace xla
-}  // end namespace mlir
 
-LogicalResult mlir::xla_hlo::legalizeTF(Operation *op,
-                                        bool allow_partial_conversion) {
+LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
   MLIRContext *context = op->getContext();
 
   // Add lowering patterns to the list.
   OwningRewritePatternList patterns;
-  xla::populateWithGenerated(context, &patterns);
+  populateWithGenerated(context, &patterns);
 
   // Add patterns that lower some of the high level TensorFlow ops to lower
   // level TensorFlow ops. So, we don't have to target all the TensorFlow ops
   // here for lowering to HLO.
-  mlir::TF::PopulateLoweringTFPatterns(context, &patterns);
-  patterns.insert,
-                  mlir::xla::ConvertSoftmaxOp,
-                  mlir::xla::ConvertSplitOp, mlir::xla::ConvertStridedSliceOp,
-                  mlir::xla::ConvertMeanOp, mlir::xla::ConvertSumOp,
-                  mlir::xla::ConvertMaxOp, mlir::xla::ConvertTileOp,
-                  mlir::xla::ConvertMaxPoolGradOp, mlir::xla::ConvertOneHotOp,
-                  mlir::xla::ConvertConv2DBackpropInputOp,
-                  mlir::xla::ConvertConv2DBackpropFilterOp>(op->getContext());
+  TF::PopulateLoweringTFPatterns(context, &patterns);
+  patterns
+      .insert,
+              ConvertSoftmaxOp, ConvertSplitOp,
+              ConvertStridedSliceOp, ConvertMeanOp, ConvertSumOp, ConvertMaxOp,
+              ConvertTileOp, ConvertMaxPoolGradOp, ConvertOneHotOp,
+              ConvertConv2DBackpropInputOp, ConvertConv2DBackpropFilterOp>(
+          op->getContext());
 
   ConversionTarget target(*context);
   target.addLegalDialect();
 
   if (!allow_partial_conversion) {
-    target.addLegalOp();
+    // Fully qualify ReturnOp here as xla_hlo dialect also defines a ReturnOp.
+    target.addLegalOp();
     return applyFullConversion(op, target, patterns);
   }
 
@@ -1904,10 +1883,19 @@ LogicalResult mlir::xla_hlo::legalizeTF(Operation *op,
 
 /// Performs the lowering to XLA dialect.
 void LegalizeTF::runOnFunction() {
-  if (failed(
-          mlir::xla_hlo::legalizeTF(getFunction(), allow_partial_conversion_)))
+  if (failed(legalizeTF(getFunction(), allow_partial_conversion_)))
     signalPassFailure();
 }
 
 static PassRegistration pass(
     "xla-legalize-tf", "Legalize from TensorFlow to the XLA dialect");
+
+}  // end namespace
+
+std::unique_ptr> createLegalizeTFPass(
+    bool allow_partial_conversion) {
+  return std::make_unique(allow_partial_conversion);
+}
+
+}  // end namespace xla_hlo
+}  // end namespace mlir
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
index ef11acab481..24d24e864d9 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
@@ -237,7 +237,7 @@ def : Pat<(TF_ConcatV2Op $inputs, (TF_ConstOp OneElementAttr:$axis)),
 //===----------------------------------------------------------------------===//
 
 def CastElementsToI64Elements : NativeCodeCall<
-  "::xla::ConvertElementsAttr("
+  "xla::ConvertElementsAttr("
     "$0, $_builder.getIntegerType(64)).cast()">;
 
 def : Pat<(TF_CrossReplicaSumOp $input, (TF_ConstOp $group_assignment)),

From 6b9f9af1ff1de5fab07acb4095c3ceb19ef144fd Mon Sep 17 00:00:00 2001
From: Ruoxin Sang 
Date: Mon, 2 Dec 2019 15:23:12 -0800
Subject: [PATCH 184/279] Throw an explicit error if user call TPUStrategy
 experimental_run_v2 in eager mode with a python function.

PiperOrigin-RevId: 283429576
Change-Id: I7b12b95211eb2555fa579819ec68a3cea870cf06
---
 .../distribute/custom_training_loop_test.py   |  5 +-
 tensorflow/python/distribute/tpu_strategy.py  | 60 ++++++++++++--
 tensorflow/python/distribute/values_test.py   | 79 ++++++++++++++++---
 3 files changed, 123 insertions(+), 21 deletions(-)

diff --git a/tensorflow/python/distribute/custom_training_loop_test.py b/tensorflow/python/distribute/custom_training_loop_test.py
index 1db9bff21f0..e9b283d376c 100644
--- a/tensorflow/python/distribute/custom_training_loop_test.py
+++ b/tensorflow/python/distribute/custom_training_loop_test.py
@@ -36,9 +36,8 @@ class InputIterationTest(test.TestCase, parameterized.TestCase):
 
   @combinations.generate(
       combinations.combine(
-          distribution=strategy_combinations.all_strategies,
-          mode=["eager"]
-      ))
+          distribution=strategy_combinations.strategies_minus_tpu,
+          mode=["eager"]))
   def testFullEager(self, distribution):
     dataset = self._get_dataset()
 
diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index 2dd4309537a..8f32e8e2226 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -37,6 +37,7 @@ from tensorflow.python.distribute import values
 from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
+from tensorflow.python.eager import function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import device_spec
 from tensorflow.python.framework import dtypes
@@ -82,6 +83,29 @@ def maybe_init_scope():
       yield
 
 
+def validate_experimental_run_function(fn):
+  """Validate the function passed into strategy.experimental_run_v2."""
+
+  # We allow three types of functions/objects passed into TPUStrategy
+  # experimental_run_v2 in eager mode:
+  #   1. a user annotated tf.function
+  #   2. a ConcreteFunction, this is mostly what you get from loading a saved
+  #      model.
+  #   3. a callable object and the `__call__` method itself is a tf.function.
+  #
+  # Otherwise we return an error, because we don't support eagerly running
+  # experimental_run_v2 in TPUStrategy.
+
+  if context.executing_eagerly() and not isinstance(
+      fn, def_function.Function) and not isinstance(
+          fn, function.ConcreteFunction) and not (callable(fn) and isinstance(
+              fn.__call__, def_function.Function)):
+    raise NotImplementedError(
+        "TPUStrategy.experimental_run_v2(fn, ...) does not support eager "
+        "execution. Either convert `fn` into a tf.function or consider "
+        "calling strategy.experimental_run_v2 inside a tf.function.")
+
+
 @tf_export("distribute.experimental.TPUStrategy", v1=[])
 class TPUStrategy(distribute_lib.Strategy):
   """TPU distribution strategy implementation."""
@@ -89,14 +113,36 @@ class TPUStrategy(distribute_lib.Strategy):
   def __init__(self,
                tpu_cluster_resolver=None,
                device_assignment=None):
-    """Initializes the TPUStrategy object.
+    """Synchronous training in TPU donuts or Pods.
+
+    To construct a TPUStrategy object, you need to run the
+    initialization code as below:
+
+    ```python
+    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
+    tf.config.experimental_connect_to_cluster(resolver)
+    tf.tpu.experimental.initialize_tpu_system(resolver)
+    strategy = tf.distribute.experimental.TPUStrategy(resolver)
+    ```
+
+    While using distribution strategies, the variables created within strategy's
+    scope will be replicated across all the replicas and can be kept in sync
+    using all-reduce algorithms.
+
+    To run TF2 programs on TPUs, you can either use `.compile` and
+    `.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized
+    training loop by calling `strategy.experimental_run_v2` directly. Note that
+    TPUStrategy doesn't support pure eager execution, so please make sure the
+    function passed into `strategy.experimental_run_v2` is a `tf.function` or
+    `strategy.experimental_run_v2` us called inside a `tf.function` if running
+    in eager mode.
 
     Args:
       tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
-          which provides information about the TPU cluster.
+        which provides information about the TPU cluster.
       device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to
-          specify the placement of replicas on the TPU cluster. Currently only
-          supports the usecase of using a single core within a TPU cluster.
+        specify the placement of replicas on the TPU cluster. Currently only
+        supports the usecase of using a single core within a TPU cluster.
     """
     super(TPUStrategy, self).__init__(TPUExtended(
         self, tpu_cluster_resolver, device_assignment=device_assignment))
@@ -111,6 +157,8 @@ class TPUStrategy(distribute_lib.Strategy):
   # This implementation runs a single step. It does not use infeed or outfeed.
   def experimental_run_v2(self, fn, args=(), kwargs=None):
     """See base class."""
+    validate_experimental_run_function(fn)
+
     # Note: the target function is converted to graph even when in Eager mode,
     # so autograph is on by default here.
     fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx())
@@ -157,6 +205,8 @@ class TPUStrategyV1(distribute_lib.StrategyV1):
   # This implementation runs a single step. It does not use infeed or outfeed.
   def experimental_run_v2(self, fn, args=(), kwargs=None):
     """See base class."""
+    validate_experimental_run_function(fn)
+
     fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx())
     return self.extended.tpu_run(fn, args, kwargs)
 
@@ -699,7 +749,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
         ]
 
       # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
-      if result[0] is None:
+      if result[0] is None or isinstance(result[0], ops.Operation):
         replicate_outputs = [None] * len(replicate_outputs)
       else:
         replicate_outputs = [
diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py
index d97d1155c82..26d0eb3ac32 100644
--- a/tensorflow/python/distribute/values_test.py
+++ b/tensorflow/python/distribute/values_test.py
@@ -818,13 +818,31 @@ class SyncOnReadVariablePropertiesTest(test.TestCase):
     self.assertEqual(2., self.evaluate(add1(replica_local)))
 
 
-@combinations.generate(
-    combinations.combine(
-        distribution=[
-            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
-            strategy_combinations.tpu_strategy,
-        ],
-        mode=["graph", "eager"]))
+def mirrored_and_tpu_strategy_combinations():
+  return combinations.combine(
+      distribution=[
+          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+          strategy_combinations.tpu_strategy,
+      ],
+      mode=["graph", "eager"])
+
+
+def strategy_and_run_tf_function_combinations():
+  # Test the combination of different strategies and whether a tf.function
+  # is passed into strategy.experimental_run_v2."""
+  return combinations.combine(
+      distribution=[
+          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+      ],
+      mode=["graph", "eager"],
+      experimental_run_tf_function=[True, False]) + combinations.combine(
+          distribution=[
+              strategy_combinations.tpu_strategy,
+          ],
+          mode=["graph", "eager"],
+          experimental_run_tf_function=[True])
+
+
 class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
 
   def _assign_replica_local(self, v, new):
@@ -842,6 +860,7 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
     save_path, _ = self._save_return_saver(sess, var)
     return save_path
 
+  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveAndRestoreReplicaLocalSumOneGraph(self, distribution):
     with self.cached_session() as sess:
       v, replica_local = _make_replica_local(
@@ -862,6 +881,7 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
         saver.restore(sess, save_path)
         self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))
 
+  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveAndRestoreReplicaLocalMeanOneGraph(self, distribution):
     if context.num_gpus() < 1 and context.executing_eagerly():
       self.skipTest("A GPU is not available for this test in eager mode.")
@@ -978,36 +998,46 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
         saver.restore(sess, save_path)
         self.assertEqual([1.75, 1.75], self.evaluate([v[0], v[1]]))
 
+  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveReplicaLocalRestoreReplicaLocalMean(self, distribution):
     save_path = self._save_replica_local_mean(distribution)
     self._restore_replica_local_mean(save_path, distribution)
 
+  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveReplicaLocalRestoreReplicaLocalSum(self, distribution):
     save_path = self._save_replica_local_sum(distribution)
     self._restore_replica_local_sum(save_path, distribution)
 
+  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveReplicaLocalMeanRestoreNormal(self, distribution):
     save_path = self._save_replica_local_mean(distribution)
     self._restore_normal(save_path)
 
+  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveReplicaLocalSumRestoreNormal(self, distribution):
     save_path = self._save_replica_local_sum(distribution)
     self._restore_normal(save_path)
 
+  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveNormalRestoreReplicaLocalMean(self, distribution):
     save_path = self._save_normal()
     self._restore_replica_local_mean(save_path, distribution)
 
+  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveNormalRestoreReplicaLocalSum(self, distribution):
     save_path = self._save_normal()
     self._restore_replica_local_sum(save_path, distribution)
 
-  def testAssign(self, distribution):
+  @combinations.generate(strategy_and_run_tf_function_combinations())
+  def testAssign(self, distribution, experimental_run_tf_function):
+
     def assign(fn, v, update_value, cross_replica):
       update_fn = lambda: getattr(v, fn)(update_value)
       if cross_replica:
         return update_fn()
       else:
+        if experimental_run_tf_function:
+          update_fn = def_function.function(update_fn)
         return distribution.experimental_local_results(
             distribution.experimental_run_v2(update_fn))
     updates = [("assign", 1.), ("assign_add", 1.), ("assign_sub", -1.)]
@@ -1033,12 +1063,17 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
         self.assertAllEqual(self.evaluate(component.read_value()),
                             self.evaluate(array_ops.ones_like(component)))
 
-  def testAssignDtypeConversion(self, distribution):
+  @combinations.generate(strategy_and_run_tf_function_combinations())
+  def testAssignDtypeConversion(self, distribution,
+                                experimental_run_tf_function):
+
     def assign(fn, v, update_value, cross_replica):
       update_fn = lambda: getattr(v, fn)(update_value)
       if cross_replica:
         return update_fn()
       else:
+        if experimental_run_tf_function:
+          update_fn = def_function.function(update_fn)
         return distribution.experimental_local_results(
             distribution.experimental_run_v2(update_fn))
     updates = [("assign", 1), ("assign_add", 1), ("assign_sub", -1)]
@@ -1064,6 +1099,7 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
         self.assertAllEqual(self.evaluate(component.read_value()),
                             self.evaluate(array_ops.ones_like(component)))
 
+  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testAssignWithAggregationSum(self, distribution):
     with distribution.scope():
       v = variable_scope.variable(
@@ -1076,6 +1112,7 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
       self.assertAllEqual(self.evaluate(component.read_value()),
                           self.evaluate(array_ops.ones_like(component)))
 
+  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testAssignAddSubWithAggregationSum(self, distribution):
     with distribution.scope():
       v = variable_scope.variable(
@@ -1090,7 +1127,9 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
         ValueError, "SyncOnReadVariable does not support "):
       self.evaluate(v.assign_sub(1.))
 
-  def testReadValueInReplicaContext(self, distribution):
+  @combinations.generate(strategy_and_run_tf_function_combinations())
+  def testReadValueInReplicaContext(self, distribution,
+                                    experimental_run_tf_function):
     aggregations = [
         variables_lib.VariableAggregation.NONE,
         variables_lib.VariableAggregation.SUM,
@@ -1104,12 +1143,19 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
             synchronization=variables_lib.VariableSynchronization.ON_READ,
             aggregation=aggregation)
       self.evaluate(variables_lib.global_variables_initializer())
-      results = self.evaluate(distribution.experimental_local_results(
-          distribution.experimental_run_v2(v.read_value)))
+      if experimental_run_tf_function:
+        read_var_fn = def_function.function(v.read_value)
+      else:
+        read_var_fn = v.read_value
+      results = self.evaluate(
+          distribution.experimental_local_results(
+              distribution.experimental_run_v2(read_var_fn)))
       for component, value in zip(v._values, results):
         self.assertAllEqual(self.evaluate(component.read_value()), value)
 
-  def testReadValueInCrossReplicaContext(self, distribution):
+  @combinations.generate(strategy_and_run_tf_function_combinations())
+  def testReadValueInCrossReplicaContext(self, distribution,
+                                         experimental_run_tf_function):
     aggregations = [
         variables_lib.VariableAggregation.SUM,
         variables_lib.VariableAggregation.MEAN,
@@ -1125,10 +1171,15 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
             synchronization=variables_lib.VariableSynchronization.ON_READ,
             aggregation=aggregation)
       self.evaluate(variables_lib.global_variables_initializer())
+
       def assign(v=v):
         ctx = distribution_strategy_context.get_replica_context()
         replica_id = ctx.replica_id_in_sync_group
         return v.assign(math_ops.cast(replica_id, dtypes.float32))
+
+      if experimental_run_tf_function:
+        assign = def_function.function(assign)
+
       self.evaluate(distribution.experimental_local_results(
           distribution.experimental_run_v2(assign)))
       result = self.evaluate(v.read_value())
@@ -1142,6 +1193,7 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
         expected = 0
       self.assertEqual(expected, result, aggregation)
 
+  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testReadValueWithAggregationNoneInCrossReplicaContext(self, distribution):
     with distribution.scope():
       v = variable_scope.variable(
@@ -1153,6 +1205,7 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
         ValueError, "Could not convert from .* VariableAggregation\\.NONE"):
       self.evaluate(v.read_value())
 
+  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testInitializedToSameValueInsideEagerRun(self, distribution):
     if not context.executing_eagerly(): self.skipTest("eager only")
 

From 882a1c8ed7dfeeacc44c180af38c1ce5635861be Mon Sep 17 00:00:00 2001
From: Gunhan Gulsoy 
Date: Mon, 2 Dec 2019 15:41:54 -0800
Subject: [PATCH 185/279] Remove dependence on core/lib/strings/base64.h in
 tensorflow/core/platform.

The library has been moved to core/platform.

PiperOrigin-RevId: 283433422
Change-Id: I98a272e65b7f32dfec181a564172652a4343bf59
---
 tensorflow/core/platform/cloud/BUILD                   | 3 +++
 tensorflow/core/platform/cloud/google_auth_provider.cc | 2 +-
 tensorflow/core/platform/cloud/oauth_client.cc         | 2 +-
 tensorflow/core/platform/cloud/oauth_client_test.cc    | 2 +-
 4 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD
index 1ad3d06f5bb..e38c51974fb 100644
--- a/tensorflow/core/platform/cloud/BUILD
+++ b/tensorflow/core/platform/cloud/BUILD
@@ -204,6 +204,7 @@ cc_library(
         ":retrying_utils",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/platform:base64",
         "//tensorflow/core/platform:errors",
         "//tensorflow/core/platform:path",
         "//tensorflow/core/platform:status",
@@ -275,6 +276,7 @@ cc_library(
         ":http_request",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/platform:base64",
         "//tensorflow/core/platform:errors",
         "//tensorflow/core/platform:status",
         "@boringssl//:crypto",
@@ -423,6 +425,7 @@ tf_cc_test(
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "//tensorflow/core/platform:base64",
         "//tensorflow/core/platform:path",
         "//tensorflow/core/platform:scanner",
         "@boringssl//:crypto",
diff --git a/tensorflow/core/platform/cloud/google_auth_provider.cc b/tensorflow/core/platform/cloud/google_auth_provider.cc
index 264cb041f77..b8d2acd83ff 100644
--- a/tensorflow/core/platform/cloud/google_auth_provider.cc
+++ b/tensorflow/core/platform/cloud/google_auth_provider.cc
@@ -25,7 +25,7 @@ limitations under the License.
 
 #include "absl/strings/match.h"
 #include "include/json/json.h"
-#include "tensorflow/core/lib/strings/base64.h"
+#include "tensorflow/core/platform/base64.h"
 #include "tensorflow/core/platform/cloud/retrying_utils.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/errors.h"
diff --git a/tensorflow/core/platform/cloud/oauth_client.cc b/tensorflow/core/platform/cloud/oauth_client.cc
index 69ba1f0926e..bd4b3ae0b5c 100644
--- a/tensorflow/core/platform/cloud/oauth_client.cc
+++ b/tensorflow/core/platform/cloud/oauth_client.cc
@@ -27,7 +27,7 @@ limitations under the License.
 #include 
 #include 
 #include 
-#include "tensorflow/core/lib/strings/base64.h"
+#include "tensorflow/core/platform/base64.h"
 #include "tensorflow/core/platform/cloud/curl_http_request.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/errors.h"
diff --git a/tensorflow/core/platform/cloud/oauth_client_test.cc b/tensorflow/core/platform/cloud/oauth_client_test.cc
index 890e75a7036..8dfff63873f 100644
--- a/tensorflow/core/platform/cloud/oauth_client_test.cc
+++ b/tensorflow/core/platform/cloud/oauth_client_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
 #include 
 #include 
 #include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/base64.h"
+#include "tensorflow/core/platform/base64.h"
 #include "tensorflow/core/platform/cloud/http_request_fake.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/path.h"

From 6a0e85b2da6cc823fe43cdc735ad56012f1e8005 Mon Sep 17 00:00:00 2001
From: Advait Jain 
Date: Mon, 2 Dec 2019 15:54:42 -0800
Subject: [PATCH 186/279] Put all arduino-cli build artifacts into
 tflite-arduino-build directory.

PiperOrigin-RevId: 283435814
Change-Id: I8407a59663c6ff686892b56f251b3d4e47668a1a
---
 .../micro/tools/ci_build/test_arduino_library.sh           | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/tensorflow/lite/experimental/micro/tools/ci_build/test_arduino_library.sh b/tensorflow/lite/experimental/micro/tools/ci_build/test_arduino_library.sh
index 8911f0d4274..04a5a617655 100755
--- a/tensorflow/lite/experimental/micro/tools/ci_build/test_arduino_library.sh
+++ b/tensorflow/lite/experimental/micro/tools/ci_build/test_arduino_library.sh
@@ -23,7 +23,6 @@ set -e
 ARDUINO_HOME_DIR=${HOME}/Arduino
 ARDUINO_LIBRARIES_DIR=${ARDUINO_HOME_DIR}/libraries
 ARDUINO_CLI_TOOL=/tmp/arduino-cli
-# Necessary due to bug in arduino-cli that allows it to build files in pwd
 TEMP_BUILD_DIR=/tmp/tflite-arduino-build
 
 LIBRARY_ZIP=${1}
@@ -57,11 +56,9 @@ InstallLibraryDependencies () {
 
 InstallLibraryDependencies
 
-# Change into this dir before running the tests
-cd ${TEMP_BUILD_DIR}
-
 for f in ${ARDUINO_LIBRARIES_DIR}/tensorflow_lite/examples/*/*.ino; do
-  ${ARDUINO_CLI_TOOL} compile --fqbn arduino:mbed:nano33ble $f
+  ${ARDUINO_CLI_TOOL} compile --build-cache-path ${TEMP_BUILD_DIR} --build-path ${TEMP_BUILD_DIR} --fqbn arduino:mbed:nano33ble $f
 done
 
 rm -rf ${ARDUINO_LIBRARIES_DIR}
+rm -rf ${TEMP_BUILD_DIR}

From 71681bd6914fed82dabfca3518772b141b1b0cd7 Mon Sep 17 00:00:00 2001
From: Jose Baiocchi 
Date: Mon, 2 Dec 2019 16:09:38 -0800
Subject: [PATCH 187/279] Only save per-thread events when non-empty

PiperOrigin-RevId: 283439027
Change-Id: I33118a45a69fff30b3c29d927c8a8380536dd1af
---
 tensorflow/core/profiler/internal/traceme_recorder.cc | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/tensorflow/core/profiler/internal/traceme_recorder.cc b/tensorflow/core/profiler/internal/traceme_recorder.cc
index d191a49fc94..3257a347d66 100644
--- a/tensorflow/core/profiler/internal/traceme_recorder.cc
+++ b/tensorflow/core/profiler/internal/traceme_recorder.cc
@@ -199,7 +199,9 @@ void TraceMeRecorder::RegisterThread(int32 tid, ThreadLocalRecorder* thread) {
 void TraceMeRecorder::UnregisterThread(TraceMeRecorder::ThreadEvents&& events) {
   mutex_lock lock(mutex_);
   threads_.erase(events.thread.tid);
-  orphaned_events_.push_back(std::move(events));
+  if (!events.events.empty()) {
+    orphaned_events_.push_back(std::move(events));
+  }
 }
 
 // This method is performance critical and should be kept fast. It is called
@@ -211,7 +213,10 @@ TraceMeRecorder::Events TraceMeRecorder::Clear() {
   std::swap(orphaned_events_, result);
   for (const auto& entry : threads_) {
     auto* recorder = entry.second;
-    result.push_back(recorder->Clear());
+    TraceMeRecorder::ThreadEvents events = recorder->Clear();
+    if (!events.events.empty()) {
+      result.push_back(std::move(events));
+    }
   }
   return result;
 }

From a46fa0b4059a159cfeccc31c1d825bdbeacea70c Mon Sep 17 00:00:00 2001
From: Amit Patankar 
Date: Mon, 2 Dec 2019 16:12:40 -0800
Subject: [PATCH 188/279] Export the bfloat16 classes and functions from C++ to
 Python with pybind11 instead of swig. This is part of a larger effort to
 deprecate swig and eventually with modularization break pywrap_tensorflow
 into smaller components. It will also make exporting C++ ops to Python
 significantly easier. XLA is using the pybind11 macros already. Please refer
 to
 https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md
 for more information.

PiperOrigin-RevId: 283439638
Change-Id: I8ca8e5e4835995f78b8b1d78036a98de444508d3
---
 tensorflow/python/BUILD                        | 14 +++++++++++++-
 tensorflow/python/framework/dtypes.py          | 11 ++++++-----
 tensorflow/python/lib/core/bfloat16.cc         |  4 +++-
 tensorflow/python/lib/core/bfloat16_test.py    |  4 ++--
 .../core/{bfloat16.i => bfloat16_wrapper.cc}   | 18 ++++++------------
 tensorflow/python/tensorflow.i                 |  2 --
 .../python/training/moving_averages_test.py    |  5 ++---
 7 files changed, 32 insertions(+), 26 deletions(-)
 rename tensorflow/python/lib/core/{bfloat16.i => bfloat16_wrapper.cc} (70%)

diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 613f20e097c..4518aeca3bf 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -400,6 +400,17 @@ cc_library(
     ],
 )
 
+tf_python_pybind_extension(
+    name = "_pywrap_bfloat16",
+    srcs = ["lib/core/bfloat16_wrapper.cc"],
+    hdrs = ["lib/core/bfloat16.h"],
+    module_name = "_pywrap_bfloat16",
+    deps = [
+        "//third_party/python_runtime:headers",
+        "@pybind11",
+    ],
+)
+
 cc_library(
     name = "ndarray_tensor_bridge",
     srcs = ["lib/core/ndarray_tensor_bridge.cc"],
@@ -1158,6 +1169,7 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         ":_dtypes",
+        ":_pywrap_bfloat16",
         ":pywrap_tensorflow",
         "//tensorflow/core:protos_all_py",
     ],
@@ -5442,7 +5454,6 @@ tf_py_wrap_cc(
         "grappler/cost_analyzer.i",
         "grappler/item.i",
         "grappler/tf_optimizer.i",
-        "lib/core/bfloat16.i",
         "lib/core/strings.i",
         "lib/io/file_io.i",
         "lib/io/py_record_reader.i",
@@ -5528,6 +5539,7 @@ WIN_LIB_FILES_FOR_EXPORTED_SYMBOLS = [
     ":numpy_lib",  # checkpoint_reader
     ":safe_ptr",  # checkpoint_reader
     ":python_op_gen",  # python_op_gen
+    ":bfloat16_lib",  # bfloat16
     "//tensorflow/core/util/tensor_bundle",  # checkpoint_reader
 ]
 
diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py
index 6bcf71915c7..828f30c40eb 100644
--- a/tensorflow/python/framework/dtypes.py
+++ b/tensorflow/python/framework/dtypes.py
@@ -21,14 +21,15 @@ import numpy as np
 from six.moves import builtins
 
 from tensorflow.core.framework import types_pb2
-# pywrap_tensorflow must be imported prior to _dtypes for the MacOS linker
-# to resolve the protobufs properly.
-# pylint: disable=unused-import,g-bad-import-order
-from tensorflow.python import pywrap_tensorflow
+# We need to import pywrap_tensorflow prior to the bfloat wrapper to avoid
+# protobuf errors where a file is defined twice on MacOS.
+# pylint: disable=invalid-import-order,g-bad-import-order
+from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
+from tensorflow.python import _pywrap_bfloat16
 from tensorflow.python import _dtypes
 from tensorflow.python.util.tf_export import tf_export
 
-_np_bfloat16 = pywrap_tensorflow.TF_bfloat16_type()
+_np_bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()
 
 
 # pylint: disable=slots-on-old-class
diff --git a/tensorflow/python/lib/core/bfloat16.cc b/tensorflow/python/lib/core/bfloat16.cc
index 54be76375c9..42b248a7ddb 100644
--- a/tensorflow/python/lib/core/bfloat16.cc
+++ b/tensorflow/python/lib/core/bfloat16.cc
@@ -532,7 +532,9 @@ struct Bfloat16GeFunctor {
 
 // Initializes the module.
 bool Initialize() {
-  // It's critical to import umath to avoid crash in open source build.
+  // It's critical to ImportNumpy and import umath
+  // to avoid crash in open source build.
+  ImportNumpy();
   import_umath1(false);
 
   Safe_PyObjectPtr numpy_str = make_safe(MakePyString("numpy"));
diff --git a/tensorflow/python/lib/core/bfloat16_test.py b/tensorflow/python/lib/core/bfloat16_test.py
index bc928cd9e5e..32453ae2296 100644
--- a/tensorflow/python/lib/core/bfloat16_test.py
+++ b/tensorflow/python/lib/core/bfloat16_test.py
@@ -24,12 +24,12 @@ import math
 import numpy as np
 
 # pylint: disable=unused-import,g-bad-import-order
-from tensorflow.python import pywrap_tensorflow
+from tensorflow.python import _pywrap_bfloat16
 from tensorflow.python.framework import dtypes
 from tensorflow.python.platform import test
 
 
-bfloat16 = pywrap_tensorflow.TF_bfloat16_type()
+bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()
 
 
 class Bfloat16Test(test.TestCase):
diff --git a/tensorflow/python/lib/core/bfloat16.i b/tensorflow/python/lib/core/bfloat16_wrapper.cc
similarity index 70%
rename from tensorflow/python/lib/core/bfloat16.i
rename to tensorflow/python/lib/core/bfloat16_wrapper.cc
index 10444b676b2..4a8e180c154 100644
--- a/tensorflow/python/lib/core/bfloat16.i
+++ b/tensorflow/python/lib/core/bfloat16_wrapper.cc
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -13,18 +13,12 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-%{
+#include "include/pybind11/pybind11.h"
 #include "tensorflow/python/lib/core/bfloat16.h"
-%}
 
-%init %{
-tensorflow::RegisterNumpyBfloat16();
-%}
+PYBIND11_MODULE(_pywrap_bfloat16, m) {
+  tensorflow::RegisterNumpyBfloat16();
 
-%{
-PyObject* TF_bfloat16_type() {
-  return tensorflow::Bfloat16PyType();
+  m.def("TF_bfloat16_type",
+        [] { return pybind11::handle(tensorflow::Bfloat16PyType()); });
 }
-%}
-
-PyObject* TF_bfloat16_type();
diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i
index 413b5126e77..761e6f376f8 100644
--- a/tensorflow/python/tensorflow.i
+++ b/tensorflow/python/tensorflow.i
@@ -21,8 +21,6 @@ limitations under the License.
 
 %include "tensorflow/python/client/tf_session.i"
 
-%include "tensorflow/python/lib/core/bfloat16.i"
-
 %include "tensorflow/python/lib/io/file_io.i"
 
 %include "tensorflow/python/lib/io/py_record_reader.i"
diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py
index 3a52d7653f4..1aa8947fb1f 100644
--- a/tensorflow/python/training/moving_averages_test.py
+++ b/tensorflow/python/training/moving_averages_test.py
@@ -18,7 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -131,7 +130,6 @@ class MovingAveragesTest(test.TestCase):
 
   @test_util.deprecated_graph_mode_only
   def testWeightedMovingAverageBfloat16(self):
-    bfloat16 = pywrap_tensorflow.TF_bfloat16_type()
     with self.cached_session() as sess:
       decay = 0.5
       weight = array_ops.placeholder(dtypes.bfloat16, [])
@@ -154,7 +152,8 @@ class MovingAveragesTest(test.TestCase):
       wma_array = sess.run(wma, feed_dict={val: val_2, weight: weight_2})
       numerator_2 = numerator_1 * decay + val_2 * weight_2 * (1.0 - decay)
       denominator_2 = denominator_1 * decay + weight_2 * (1.0 - decay)
-      self.assertAllClose(bfloat16(numerator_2 / denominator_2), wma_array)
+      self.assertAllClose(
+          dtypes._np_bfloat16(numerator_2 / denominator_2), wma_array)
 
 
 def _Repeat(value, dim):

From 513f16d55d31f2daf21972f910b0438fad3f151e Mon Sep 17 00:00:00 2001
From: Smit Hinsu 
Date: Mon, 2 Dec 2019 16:24:53 -0800
Subject: [PATCH 189/279] Lower TensorFlow Einsum op to HLO

Specifically,

* Verify TF Einsum op arity
* Define Einsum and UnaryEinsum op in HLO. UnaryEinsum is defined so that the op is not
  variadic or requires conversion from unary to binary at the time of import. This way
  translations are simplified and also it is easier to operand on the op.
* Convert TF Einsum op to HLO Einsum op or UnaryEinsum op.
* Canonicalize UnaryEinsum to Einsum op with two

Also, added support for the StringAttr in HLO exporter generator.

PiperOrigin-RevId: 283441951
Change-Id: I0d86ae2723209b08eb2c243a9dcde843dbd5897c
---
 .../mlir/tensorflow/ir/tf_generated_ops.td    |  4 ++
 .../compiler/mlir/tensorflow/ir/tf_ops.cc     | 15 +++++
 .../mlir/tensorflow/tests/tf-ops.mlir         |  8 +++
 tensorflow/compiler/mlir/xla/BUILD            |  1 +
 tensorflow/compiler/mlir/xla/ir/hlo_ops.cc    | 11 +++-
 tensorflow/compiler/mlir/xla/ir/hlo_ops.td    | 37 +++++++++++
 tensorflow/compiler/mlir/xla/ir/hlo_utils.cc  | 15 +++++
 tensorflow/compiler/mlir/xla/ir/hlo_utils.h   |  7 ++
 tensorflow/compiler/mlir/xla/ir/hlo_utils.td  |  7 +-
 .../compiler/mlir/xla/mlir_hlo_to_hlo.cc      | 11 ++++
 .../compiler/mlir/xla/operator_writer_gen.cc  | 11 ++--
 .../compiler/mlir/xla/tests/canonicalize.mlir |  8 +++
 .../compiler/mlir/xla/tests/legalize-tf.mlir  | 14 ++++
 .../mlir/xla/tests/translate/einsum.mlir      |  9 +++
 .../mlir/xla/transforms/canonicalize.td       | 12 +++-
 .../mlir/xla/transforms/legalize_tf.cc        | 64 ++++++++++++-------
 16 files changed, 202 insertions(+), 32 deletions(-)
 create mode 100644 tensorflow/compiler/mlir/xla/tests/translate/einsum.mlir

diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index b68634ba704..57b61461d02 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -1358,6 +1358,10 @@ Comparison with `numpy.einsum`:
 
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
   TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
+
+  let verifier = [{
+    return Verify(*this);
+  }];
 }
 
 def TF_EluOp : TF_Op<"Elu", [NoSideEffect, SameOperandsAndResultType]> {
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
index a58e20a9952..3b836a6188d 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
@@ -674,6 +674,21 @@ void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
   results.insert(context);
 }
 
+//===----------------------------------------------------------------------===//
+// EinsumOp
+//===----------------------------------------------------------------------===//
+
+// Verifies that,
+// * Arity of the op is at most two.
+//
+// TODO(hinsu): Verify einsum equation attribute.
+static LogicalResult Verify(EinsumOp op) {
+  if (op.N() > 2) {
+    return op.emitOpError("supports at most two operands");
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // EmptyTensorListOp
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
index b8e7ba71198..1914ca177cc 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
@@ -1650,3 +1650,11 @@ func @testSplitSmallSplitDim(%input: tensor<4x8xf32>) {
   %0:3 = "tf.Split"(%cst, %input) : (tensor, tensor<4x8xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>)
   return
 }
+
+// -----
+
+func @testTernaryEinsum(%arg0: tensor<2x3xf32>){
+  // expected-error @+1 {{supports at most two operands}}
+  %0 = "tf.Einsum"(%arg0, %arg0, %arg0) {equation = "ab,cd,ef->"} : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> (tensor<*xf32>)
+  return
+}
diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index 3ed3fb6fc40..ac3475cebc4 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -404,6 +404,7 @@ cc_library(
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/compiler/xla/client/lib:matrix",
         "//tensorflow/compiler/xla/service:hlo",
         "@llvm//:support",
         "@local_config_mlir//:Analysis",
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
index 639c85c48b5..8fa33d19363 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
@@ -47,10 +47,10 @@ limitations under the License.
 #include "mlir/Transforms/InliningUtils.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
 #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h.inc"
+#include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h"
 
 namespace mlir {
 #include "tensorflow/compiler/mlir/xla/ir/hlo_structs.cc.inc"
-
 namespace xla_hlo {
 
 Operation* XlaHloDialect::materializeConstant(OpBuilder& builder,
@@ -936,6 +936,15 @@ void TupleOp::build(Builder* builder, OperationState& result,
   build(builder, result, builder->getTupleType(types), values);
 }
 
+//===----------------------------------------------------------------------===//
+// UnaryEinsumOp
+//===----------------------------------------------------------------------===//
+
+void UnaryEinsumOp::getCanonicalizationPatterns(
+    OwningRewritePatternList& results, MLIRContext* context) {
+  results.insert(context);
+}
+
 //===----------------------------------------------------------------------===//
 // CompareOp
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
index f036dec92b9..4fb85f9f6b3 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
@@ -730,6 +730,43 @@ def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>, BASE_HLO_DotGeneral
   let results = (outs HLO_Tensor);
 }
 
+def BASE_EinsumOp {
+  string summary = "Einsum operator";
+
+  string description = [{
+    Returns a tensor whose elements are defined by equation, which is written
+    in a shorthand form inspired by the Einstein summation convention.
+  }];
+}
+
+def HLO_EinsumOp: HLO_Op<"einsum", [NoSideEffect]> {
+  let arguments = (ins
+    HLO_Tensor:$lhs,
+    HLO_Tensor:$rhs,
+    StrAttr:$einsum_config
+  );
+
+  let results = (outs HLO_Tensor);
+
+  // TODO(hinsu): Canonicalize to lower this client side HLO op to server
+  // side HLO ops.
+}
+
+def HLO_UnaryEinsumOp: HLO_Op<"unary_einsum", [NoSideEffect]> {
+  let arguments = (ins
+    HLO_Tensor:$operand,
+    StrAttr:$einsum_config
+  );
+
+  let results = (outs HLO_Tensor);
+
+  let hasCanonicalizer = 1;
+
+  // UnarayEinsumOp is unconditionally canonicalized to the binary EinsumOp so
+  // the HLO converter shouldn't be invoked.
+  let hasCustomHLOConverter = 1;
+}
+
 def HLO_FftOp: HLO_Op<"fft", [NoSideEffect]>, BASE_HLO_FftOp {
   let arguments = (ins
     HLO_Tensor:$operand,
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc b/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc
index 82b7032d542..7d3e2ca2384 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc
@@ -17,6 +17,8 @@ limitations under the License.
 
 #include 
 
+#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
+
 namespace mlir {
 namespace xla {
 
@@ -51,5 +53,18 @@ DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value *x,
   return DenseIntElementsAttr::get(type, broadcastDimensions);
 }
 
+DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
+  RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
+
+  DenseElementsAttr attr;
+  if (auto float_ty = ty.dyn_cast()) {
+    APFloat value(float_ty.getFloatSemantics(), raw_value);
+    return DenseElementsAttr::get(scalar_ty, value);
+  }
+  auto int_ty = ty.cast();
+  APInt value(int_ty.getWidth(), static_cast(raw_value), true);
+  return DenseElementsAttr::get(scalar_ty, value);
+}
+
 }  // namespace xla
 }  // namespace mlir
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_utils.h b/tensorflow/compiler/mlir/xla/ir/hlo_utils.h
index d81abf6a0be..86c90b49f16 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_utils.h
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_utils.h
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include "mlir/IR/Attributes.h"  // TF:local_config_mlir
 #include "mlir/IR/Builders.h"  // TF:local_config_mlir
+#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
 #include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
 #include "mlir/IR/TypeUtilities.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
@@ -48,6 +49,12 @@ static ElementsAttr getSplat(Builder* b, Value* val, T constant) {
 
   return DenseElementsAttr::get(valType, elementAttr);
 }
+
+// Returns DenseElementsAttr of rank zero with the given element type and the
+// value.
+// Requires `ty` to be either FloatType of IntegerType.
+DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value);
+
 }  // namespace xla
 }  // namespace mlir
 
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_utils.td b/tensorflow/compiler/mlir/xla/ir/hlo_utils.td
index 1a56d230d0d..97b29bf0851 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_utils.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_utils.td
@@ -18,9 +18,7 @@ limitations under the License.
 #ifndef HLO_UTILS
 #define HLO_UTILS
 
-#ifndef OP_BASE
 include "mlir/IR/OpBase.td"
-#endif // OP_BASE
 
 def NullArrayAttr : NativeCodeCall<"ArrayAttr()">;
 
@@ -34,4 +32,9 @@ def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
 def BinBroadcastDimensions : NativeCodeCall<
     "xla::getBroadcastDimensionsAttr(&$_builder, $0, $1)">;
 
+// Here, the element type can be any integer or float type. But, note that only
+// 32 bit integers are supported for the value.
+class GetScalarOfType : NativeCodeCall<
+  "xla::GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">;
+
 #endif // HLO_UTILS
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
index 267fd3b21b4..f717c8199fd 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
@@ -33,6 +33,7 @@ limitations under the License.
 #include "mlir/IR/TypeUtilities.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
+#include "tensorflow/compiler/xla/client/lib/matrix.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/comparison_util.h"
 #include "tensorflow/compiler/xla/literal_util.h"
@@ -77,6 +78,10 @@ static double ConvertAPFloat(llvm::APFloat value) {
   return value.convertToDouble();
 }
 
+static absl::string_view ConvertStringRef(mlir::StringRef value) {
+  return {value.data(), value.size()};
+}
+
 static std::vector ConvertDenseIntAttr(mlir::DenseIntElementsAttr attr) {
   auto values = attr.getValues();
   return {values.begin(), values.end()};
@@ -632,6 +637,12 @@ LogicalResult ExportXlaOp(TupleOp op, OpLoweringContext ctx) {
   return success();
 }
 
+LogicalResult ExportXlaOp(UnaryEinsumOp op, OpLoweringContext ctx) {
+  // Intentional as UnaryEinsumOp is always lowered to the EinsumOp with two
+  // operands.
+  return failure();
+}
+
 LogicalResult ExportXlaOp(WhileOp op, OpLoweringContext ctx) {
   xla::XlaComputation condition;
   xla::XlaComputation body;
diff --git a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc
index 4a9555a256a..acc3c17baf5 100644
--- a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc
+++ b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc
@@ -32,17 +32,20 @@ using llvm::raw_ostream;
 using llvm::RecordKeeper;
 using llvm::StringRef;
 using mlir::interleaveComma;
+using mlir::tblgen::Attribute;
 using mlir::tblgen::NamedAttribute;
 using mlir::tblgen::NamedTypeConstraint;
 using mlir::tblgen::Operator;
 
 static std::string GetDefaultAttrExport(
     const mlir::tblgen::NamedAttribute& named_attr) {
-  auto storage_type = named_attr.attr.getStorageType();
+  Attribute attr = named_attr.attr;
+  StringRef storage_type = attr.getStorageType();
   // For some attribute types we have a general conversion, so use that.
-  if (storage_type.endswith("IntegerAttr") ||
-      storage_type.endswith("FloatAttr")) {
-    return "Convert" + named_attr.attr.getReturnType().str();
+  if (!attr.isEnumAttr() && (storage_type.endswith("IntegerAttr") ||
+                             storage_type.endswith("FloatAttr") ||
+                             storage_type.endswith("StringAttr"))) {
+    return "Convert" + attr.getReturnType().str();
   }
   return "Convert_" + named_attr.name.str();
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir
index e6d99b9e7d8..fa39b77918a 100644
--- a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir
@@ -48,3 +48,11 @@ func @complex_collapse_fold(%arg0: tensor<4xcomplex>) -> tensor<4xcomplex>
 }
+
+// CHECK-LABEL: @unary_einsum
+func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
+  // CHECK: %[[ONE:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor
+  // CHECK: "xla_hlo.einsum"(%[[ONE]], %arg0) {einsum_config = ",ab->aa"}
+  %0 = "xla_hlo.unary_einsum"(%arg0) {einsum_config = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
+}
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 3004f2276fe..d8093f1a39a 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -215,6 +215,20 @@ func @pow_dynamic(%arg0: tensor) -> tensor {
   return %0: tensor
 }
 
+// CHECK-LABEL: func @einsum
+func @einsum(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x4xf32> {
+  // CHECK:  xla_hlo.einsum
+  %0 = "tf.Einsum"(%arg0, %arg1) {equation = "ab,bc->ac"} : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32>
+  return %0: tensor<2x4xf32>
+}
+
+// CHECK-LABEL: func @unary_einsum
+func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
+  // CHECK:  xla_hlo.unary_einsum
+  %0 = "tf.Einsum"(%arg0) {equation = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
+  return %0: tensor<2x2xf32>
+}
+
 // CHECK-LABEL: func @floordiv_broadcast_i32
 func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
   // CHECK-DAG: [[ZEROS1:%.+]] = xla_hlo.constant dense<0>
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/einsum.mlir b/tensorflow/compiler/mlir/xla/tests/translate/einsum.mlir
new file mode 100644
index 00000000000..e703a5cb872
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/tests/translate/einsum.mlir
@@ -0,0 +1,9 @@
+// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
+
+// CHECK-LABEL: ENTRY
+func @main(%arg0: tensor<3x4xi32>, %arg1: tensor<4x5xi32>) -> tensor<3x5xi32> {
+  // Simple einsum is lowered to HLO dot op.
+  // CHECK: dot(s32[3,4] %{{.*}}, s32[4,5] %{{.*}}), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  %0 = "xla_hlo.einsum"(%arg0, %arg1) {einsum_config = "ab,bc->ac"} : (tensor<3x4xi32>, tensor<4x5xi32>) -> tensor<3x5xi32>
+  return %0 : tensor<3x5xi32>
+}
diff --git a/tensorflow/compiler/mlir/xla/transforms/canonicalize.td b/tensorflow/compiler/mlir/xla/transforms/canonicalize.td
index bc44117910b..37f6d7deaa3 100644
--- a/tensorflow/compiler/mlir/xla/transforms/canonicalize.td
+++ b/tensorflow/compiler/mlir/xla/transforms/canonicalize.td
@@ -17,7 +17,7 @@ limitations under the License.
 
 include "mlir/IR/OpBase.td"
 include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
-
+include "tensorflow/compiler/mlir/xla/ir/hlo_utils.td"
 
 //===----------------------------------------------------------------------===//
 // DynamicSlice op patterns.
@@ -37,3 +37,13 @@ def DynamicSliceToSlice: Pat<(HLO_DynamicSliceOp HLO_Tensor:$input,
           (HLO_SliceOp $input, (CastIntElementsAttr $starting_indices),
            (BuildSliceLimits $starting_indices, $slice_sizes),
             (BuildSliceStrides $input))>;
+
+def UnaryToBianryEinsumEq : NativeCodeCall<
+  "$_builder.getStringAttr(\",\" + $0.getValue().str())">;
+
+// Convert UnaryEinsumOp to EinsumOp with two operands with redundant first
+// operand.
+def UnaryEinsumToEinsum : Pat<
+  (HLO_UnaryEinsumOp $operand, $equation),
+  (HLO_EinsumOp (HLO_ConstOp (GetScalarOfType<1> $operand)),
+                $operand, (UnaryToBianryEinsumEq $equation))>;
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index a156685f005..94e0ce35cb0 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/Optional.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
 #include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
 #include "mlir/IR/Attributes.h"  // TF:local_config_mlir
 #include "mlir/IR/Diagnostics.h"  // TF:local_config_mlir
@@ -168,20 +169,9 @@ static ConstOp GetMinValueForType(Type ty, Location loc,
 
 // Returns int or float scalar DenseElementsAttr attribute with the given
 // element type and the value.
-static ConstOp GetScalarOfType(Type ty, Location loc, int64_t raw_value,
-                               PatternRewriter *rewriter) {
-  RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
-
-  DenseElementsAttr attr;
-  if (auto float_ty = ty.dyn_cast_or_null()) {
-    APFloat value(float_ty.getFloatSemantics(), raw_value);
-    attr = DenseElementsAttr::get(scalar_ty, value);
-  } else {
-    auto int_ty = ty.cast();
-    APInt value(int_ty.getWidth(), static_cast(raw_value), true);
-    attr = DenseElementsAttr::get(scalar_ty, value);
-  }
-  return rewriter->create(loc, attr);
+static ConstOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value,
+                                    PatternRewriter *rewriter) {
+  return rewriter->create(loc, xla::GetScalarOfType(ty, raw_value));
 }
 
 // Builds body for reduce op by using the using the template binary op as the
@@ -639,6 +629,31 @@ class ConvertBF16FloorDivOp : public OpRewritePattern {
   }
 };
 
+// Converts TensorFlow EinsumOp to either HLO EinsumOp or UnaryEinsumOp
+// depending on arity of the op.
+class ConvertEinsumOp : public OpRewritePattern {
+ public:
+  using OpRewritePattern::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(TF::EinsumOp op,
+                                     PatternRewriter &rewriter) const override {
+    StringAttr equation = op.getAttrOfType("equation");
+    if (op.N() == 1) {
+      rewriter.replaceOpWithNewOp(
+          op, op.getType(), *op.inputs().begin(), equation);
+    } else if (op.N() == 2) {
+      auto inputs = llvm::to_vector<2>(op.inputs());
+      rewriter.replaceOpWithNewOp(op, op.getType(), inputs[0],
+                                            inputs[1], equation);
+    } else {
+      // TensorFlow EinsumOp verifies that the number of operands are at most
+      // two.
+      return Pattern::matchFailure();
+    }
+    return Pattern::matchSuccess();
+  }
+};
+
 // Converts MaxPool op to HLO ReduceWindow op by setting appropriate window
 // dimensions with max as the reduction function.
 //
@@ -847,8 +862,8 @@ class ConvertSizeOp : public OpRewritePattern {
     const int64_t rank = input_ty.getRank();
     auto result_type = op.getResult()->getType();
     Operation *size =
-        GetScalarOfType(result_type.cast().getElementType(),
-                        op.getLoc(), 1, &rewriter);
+        GetScalarConstOfType(result_type.cast().getElementType(),
+                             op.getLoc(), 1, &rewriter);
     for (int64_t i = 0; i < rank; ++i) {
       auto dim = rewriter.create(
           op.getLoc(), result_type, input,
@@ -1169,8 +1184,8 @@ class GenericConvertReductionOp : public OpRewritePattern {
           divisor_count *= input_shape[i];
         }
       }
-      auto divisor =
-          GetScalarOfType(reduce_element_type, loc, divisor_count, &rewriter);
+      auto divisor = GetScalarConstOfType(reduce_element_type, loc,
+                                          divisor_count, &rewriter);
       auto broadcast_dims = GetI64ElementsAttr({}, &rewriter);
       result = rewriter.create(loc, result, divisor.getResult(),
                                       broadcast_dims);
@@ -1203,7 +1218,7 @@ class ConvertMeanOp
 
   static Value *GetInitialValue(Type reduce_element_type, Location loc,
                                 PatternRewriter &rewriter) {
-    return GetScalarOfType(reduce_element_type, loc, 0, &rewriter);
+    return GetScalarConstOfType(reduce_element_type, loc, 0, &rewriter);
   }
 };
 
@@ -1219,7 +1234,7 @@ class ConvertSumOp
 
   static Value *GetInitialValue(Type reduce_element_type, Location loc,
                                 PatternRewriter &rewriter) {
-    return GetScalarOfType(reduce_element_type, loc, 0, &rewriter);
+    return GetScalarConstOfType(reduce_element_type, loc, 0, &rewriter);
   }
 };
 
@@ -1274,7 +1289,7 @@ class ConvertArgMinMaxOp : public OpRewritePattern {
 
     Type index_element_type = output_type.getElementType();
     Value *index_init_value =
-        GetScalarOfType(index_element_type, loc, 0, &rewriter);
+        GetScalarConstOfType(index_element_type, loc, 0, &rewriter);
 
     RankedTensorType index_type =
         RankedTensorType::get(input_type.getShape(), index_element_type);
@@ -1418,7 +1433,7 @@ class ConvertMaxPoolGradOp : public OpRewritePattern {
 
     auto result = rewriter.create(
         loc, op.getType(), op.orig_input(), op.grad(),
-        GetScalarOfType(element_type, loc, 0, &rewriter),
+        GetScalarConstOfType(element_type, loc, 0, &rewriter),
         GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
         nullptr);
 
@@ -1860,8 +1875,9 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
   TF::PopulateLoweringTFPatterns(context, &patterns);
   patterns
       .insert,
+              ConvertEinsumOp, ConvertMaxPoolOp, ConvertRangeOp,
+              ConvertSigmoidOp, ConvertSizeOp, ConvertMaxPoolOp, ConvertRangeOp,
+              ConvertSigmoidOp, ConvertSoftmaxOp,
               ConvertSoftmaxOp, ConvertSplitOp,
               ConvertStridedSliceOp, ConvertMeanOp, ConvertSumOp, ConvertMaxOp,
               ConvertTileOp, ConvertMaxPoolGradOp, ConvertOneHotOp,

From 250d9bc96b039ce903dc4e122f831b67b7358073 Mon Sep 17 00:00:00 2001
From: Gaurav Jain 
Date: Mon, 2 Dec 2019 16:43:53 -0800
Subject: [PATCH 190/279] Replace Notification with simple mutex

This avoids the overheads of the Notification object such as mutex
acquisition during destruction.

PiperOrigin-RevId: 283445581
Change-Id: Ic30ea13186096c23ec775eac13412c6ffe6c9a0a
---
 .../common_runtime/eager/tensor_handle.cc     | 46 ++++++++++---------
 .../core/common_runtime/eager/tensor_handle.h | 19 ++++----
 .../eager/tensor_handle_test.cc               |  1 -
 3 files changed, 33 insertions(+), 33 deletions(-)

diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc
index 9bda0512b3d..717ec586eef 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc
@@ -132,10 +132,9 @@ TensorHandle::TensorHandle(std::unique_ptr t,
       ctx_(ctx),
       is_remote_(false),
       is_async_(false),
+      is_ready_(true),
       tensor_handle_data_(std::move(t)) {
   DVLOG(3) << "Creating Local TensorHandle: " << this << " device: " << device_;
-  // Notify immediately since this handle is already ready.
-  is_ready_notification_.Notify();
 }
 
 TensorHandle::TensorHandle(std::unique_ptr t,
@@ -152,11 +151,10 @@ TensorHandle::TensorHandle(std::unique_ptr t,
       ctx_(ctx),
       is_remote_(false),
       is_async_(false),
+      is_ready_(true),
       handle_dtypes_and_shapes_(resource_handle.dtypes_and_shapes()),
       tensor_handle_data_(std::move(t)) {
   DVLOG(3) << "Creating Local TensorHandle: " << this << " device: " << device_;
-  // Notify immediately since this handle is already ready.
-  is_ready_notification_.Notify();
 }
 
 Status TensorHandle::CreateEmptyLocalHandle(bool async, Device* d,
@@ -185,12 +183,10 @@ TensorHandle::TensorHandle(std::unique_ptr t,
       ctx_(ctx),
       is_remote_(false),
       is_async_(async),
+      is_ready_(!async),
       tensor_handle_data_(std::move(t)) {
   DVLOG(3) << "Creating Async Local TensorHandle: " << this
            << " device: " << device_;
-  if (!async) {
-    is_ready_notification_.Notify();
-  }
 }
 
 #if !defined(IS_MOBILE_PLATFORM)
@@ -227,11 +223,10 @@ TensorHandle::TensorHandle(std::unique_ptr t,
       ctx_(ctx),
       is_remote_(true),
       is_async_(false),
+      is_ready_(true),
       tensor_handle_data_(std::move(t)) {
   DVLOG(3) << "Creating Remote TensorHandle: " << this
            << " device: " << device_;
-  // Notify immediately since this handle is already ready.
-  is_ready_notification_.Notify();
 }
 
 Status TensorHandle::CreateUnshapedRemoteHandle(
@@ -264,21 +259,29 @@ TensorHandle::TensorHandle(std::unique_ptr t,
       ctx_(ctx),
       is_remote_(true),
       is_async_(true),
+      is_ready_(false),
       tensor_handle_data_(std::move(t)) {
   DVLOG(3) << "Creating Unshaped Remote TensorHandle: " << this
            << " device: " << device_;
 }
 #endif
 
-bool TensorHandle::IsReady() {
-  return !is_async_ || is_ready_notification_.HasBeenNotified();
+bool TensorHandle::IsReady() const {
+  // Avoid mutex acquisition for local sync handles
+  if (!is_async_ && !is_remote_) {
+    return true;
+  }
+
+  tf_shared_lock l(mu_);
+  return is_ready_;
 }
 
 Status TensorHandle::WaitReady(const char* caller) {
   if (!IsReady()) {
     profiler::TraceMe activity(absl::StrCat(caller, " WaitReady"),
                                profiler::TraceMeLevel::kInfo);
-    is_ready_notification_.WaitForNotification();
+    tf_shared_lock l(mu_);
+    mu_.Await(Condition(&is_ready_));
   }
   return is_poisoned_;
 }
@@ -537,8 +540,7 @@ Status TensorHandle::SetRemoteShape(const TensorShape& shape,
   }
 
   DCHECK(is_remote_) << "SeRemoteShape is only called on remote handles.";
-  DCHECK(!is_ready_notification_.HasBeenNotified())
-      << "SetRemoteShape is only called on non-ready handles.";
+  DCHECK(!IsReady()) << "SetRemoteShape is only called on non-ready handles.";
 
   UnshapedRemoteTensorHandleData* p =
       reinterpret_cast(
@@ -548,7 +550,8 @@ Status TensorHandle::SetRemoteShape(const TensorShape& shape,
       remote_op_id_, remote_output_num_, shape, remote_task_,
       remote_context_id_, ctx_);
   is_poisoned_ = Status::OK();
-  is_ready_notification_.Notify();
+  mutex_lock l(mu_);
+  is_ready_ = true;
 
   return Status::OK();
 }
@@ -556,7 +559,7 @@ Status TensorHandle::SetRemoteShape(const TensorShape& shape,
 
 Status TensorHandle::SetTensor(tensorflow::Tensor&& tensor) {
   DCHECK(!is_remote_) << "SetTensor is not called on remote handles.";
-  DCHECK(!is_async_ || !is_ready_notification_.HasBeenNotified())
+  DCHECK(!is_async_ || !IsReady())
       << "SetTensor is only called on non-ready handles.";
 
   DVLOG(3) << "SetTensor on TensorHandle: " << this;
@@ -568,21 +571,22 @@ Status TensorHandle::SetTensor(tensorflow::Tensor&& tensor) {
   tensor_handle_data_ = absl::make_unique(tensor);
   if (is_async_) {
     is_poisoned_ = Status::OK();
-    is_ready_notification_.Notify();
+    mutex_lock l(mu_);
+    is_ready_ = true;
   }
+
   return Status::OK();
 }
 
 void TensorHandle::Poison(Status status) {
-  DCHECK(!is_async_ || !is_ready_notification_.HasBeenNotified())
+  DCHECK(!is_async_ || !IsReady())
       << "Poison(status) can only be called on non-ready handle: " << this;
 
   DVLOG(3) << "Poison on TensorHandle: " << this;
 
   is_poisoned_ = status;
-  if (is_async_ || is_remote_) {
-    is_ready_notification_.Notify();
-  }
+  mutex_lock l(mu_);
+  is_ready_ = true;
 }
 
 Status TensorHandle::CopyToDevice(EagerContext* ctx, tensorflow::Device* dstd,
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h
index f61d3d27951..c32ec834071 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.h
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.h
@@ -167,8 +167,6 @@ class TensorHandle : public core::RefCounted {
   // on a non-ready tensor.
   void Poison(Status status);
 
-  bool IsReady();
-
   Status CopyToDevice(EagerContext* ctx, tensorflow::Device* dstd,
                       tensorflow::Tensor* output);
 
@@ -207,6 +205,12 @@ class TensorHandle : public core::RefCounted {
       std::vector* result);
 
  private:
+  // The TensorHandleData can either represent a local or remote tensor handle.
+  // Further, it can be in a non-ready state. It would become ready with a call
+  // to either SetTensor or SetRemoteShape which replaces the underlying data
+  // with a ready version of the tensor handle data.
+  bool IsReady() const;
+
   // If the contents of the Tensor pointed to by this handle is yet to be
   // computed by a EagerNode, this function will block till that computation is
   // done and the handle is "ready".
@@ -232,9 +236,9 @@ class TensorHandle : public core::RefCounted {
   // backing the resource. Else resource_device_ is nullptr.
   tensorflow::Device* const resource_device_;
 
-#if !defined(IS_MOBILE_PLATFORM)
   mutable mutex mu_;
 
+#if !defined(IS_MOBILE_PLATFORM)
   // TODO(yujingzhang): Remove resource_shape_mirrors_ once scalable per-replica
   // variable is ready, since we could get the shape locally without remote copy
   // then.
@@ -263,25 +267,18 @@ class TensorHandle : public core::RefCounted {
   // `ctx` object is not owned and should outlive this handle.
   EagerContext* const ctx_;
 
-  // Explanation for NOLINT below: absl has clang-tidy macro to rename
-  // 'tensorflow::Notification' to 'absl::Notification'. TF does not use
-  // absl::Notification in open source now, so we can't follow clang-tidy
-  tensorflow::Notification is_ready_notification_;  // NOLINT
   // Does not need synchronization because it can be accessed only after
   // WaitReady() has returned. At that point, is_poisoned_ is immutable.
   Status is_poisoned_;
   const bool is_remote_;
   const bool is_async_;
+  bool is_ready_ GUARDED_BY(mu_);
 
   // If this TensorHandle 1) is a local tensor, and 2) is a resource handle or
   // refers to a remote resource handle, we store data types and shapes for
   // the underlying resource.
   std::vector handle_dtypes_and_shapes_;
 
-  // The TensorHandleData can either represent a local or remote tensor handle.
-  // Further, it can be in a non-ready state. It would become ready with a call
-  // to either SetTensor or SetRemoteShape which replaces the underlying data
-  // with a ready version of the tensor handle data.
   // Does not need synchronization because it can be accessed only after
   // WaitReady() has returned. At that point, tensor_handle_data_ is immutable.
   std::unique_ptr tensor_handle_data_;
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle_test.cc b/tensorflow/core/common_runtime/eager/tensor_handle_test.cc
index d8217e85315..ea81cda6199 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle_test.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle_test.cc
@@ -39,7 +39,6 @@ TEST(TensorHandle_ShapeTest, AsyncShape) {
                   .ok());
 
   EXPECT_TRUE(async_th->CopyInferenceShape(sync_th).ok());
-  EXPECT_FALSE(async_th->IsReady());
 
   TensorShape sync_shape;
   TensorShape async_shape;

From 51249d605d4e52754d17140ac85185119818d908 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Mon, 2 Dec 2019 17:04:12 -0800
Subject: [PATCH 191/279] Improving softmax precision.

PiperOrigin-RevId: 283449048
Change-Id: I336e2f7740305aabcea02dac22c7c47d92406bf8
---
 tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc b/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc
index 871cd505368..efaf39390d9 100644
--- a/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc
+++ b/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc
@@ -62,14 +62,17 @@ class Softmax : public NodeShader {
     std::string source = R"(
   highp float sum = 0.0;
   for (int d = 0; d < $src_depth$ - 1; ++d) {
-    sum += dot(vec4(1.0), exp($input_data_0[gid.x, gid.y, d]$));
+    highp vec4 v = $input_data_0[gid.x, gid.y, d]$;
+    sum += dot(vec4(1.0), exp(v));
   }
   {
     int d = $src_depth$ - 1;
-    sum += dot($mask$, exp($input_data_0[gid.x, gid.y, d]$));
+    highp vec4 v = $input_data_0[gid.x, gid.y, d]$;
+    sum += dot($mask$, exp(v));
   }
   for (int d = 0; d < $src_depth$; ++d) {
-    vec4 temp_sum = exp($input_data_0[gid.x, gid.y, d]$) / sum;
+    highp vec4 v = $input_data_0[gid.x, gid.y, d]$;
+    vec4 temp_sum = exp(v) / sum;
     $output_data_0[gid.x, gid.y, d] = temp_sum$;
   }
 )";

From dd8c49a13de4528478ab8b217e92316062cc56ea Mon Sep 17 00:00:00 2001
From: Gunhan Gulsoy 
Date: Mon, 2 Dec 2019 17:05:40 -0800
Subject: [PATCH 192/279] Create a new BUILD target for util/reporter. For now,
 create it in core/BUILD until core/util/BUILD is created.

PiperOrigin-RevId: 283449291
Change-Id: I787751231da74c90b8d2f3d54f7d4ab9a3fc30a5
---
 tensorflow/core/BUILD            | 32 ++++++++++++++++++++++++++++++--
 tensorflow/core/util/reporter.cc |  4 ++--
 2 files changed, 32 insertions(+), 4 deletions(-)

diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 6a810de58b0..e60e0f7b1a0 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -237,6 +237,7 @@ tf_proto_library(
     make_default_target_header_only = True,
     protodeps = [
         ":error_codes_proto_impl",
+        ":test_log_proto_impl",
         ":core_protos",
         "//tensorflow/core/framework:protos_all",
         "//tensorflow/core/lib/core:error_codes_proto",
@@ -567,6 +568,24 @@ cc_library(
     ],
 )
 
+# TODO(gunan): Move this to core/util/BUILD once the  file is created
+cc_library(
+    name = "util_reporter",
+    srcs = ["util/reporter.cc"],
+    hdrs = ["util/reporter.h"],
+    # Not to be used outside this file.
+    visibility = ["//visibility:private"],
+    deps = [
+        ":test_log_proto_impl_cc",
+        "//tensorflow/core/platform:env",
+        "//tensorflow/core/platform:errors",
+        "//tensorflow/core/platform:macros",
+        "//tensorflow/core/platform:mutex",
+        "//tensorflow/core/platform:str_util",
+        "//tensorflow/core/platform:types",
+    ],
+)
+
 # Test support library needed for all tests
 # This is currently public, but may be made internal in the
 # future.  Try to avoid depending on it.
@@ -574,7 +593,6 @@ cc_library(
     name = "test",
     testonly = 1,
     srcs = [
-        "util/reporter.cc",
         "//tensorflow/core/platform:legacy_test_srcs",
     ],
     hdrs = [
@@ -2156,6 +2174,7 @@ cc_library(
     deps = tf_additional_lib_deps() + [
         ":core_stringpiece",
         ":lib_proto_parsing",
+        ":util_reporter",  # TODO(gunan): REMOVE as soon as cc_shared_library is supported.
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "//third_party/eigen3",
@@ -2450,6 +2469,15 @@ tf_proto_library(
     make_default_target_header_only = True,
 )
 
+tf_proto_library(
+    name = "test_log_proto_impl",
+    srcs = ["util/test_log.proto"],
+    cc_api_version = 2,
+    make_default_target_header_only = True,
+    # Not to be used outside this file.
+    visibility = ["//visibility:private"],
+)
+
 tf_proto_library(
     name = "core_protos",
     srcs = COMMON_PROTO_SRCS + [
@@ -2473,12 +2501,12 @@ tf_proto_library(
         "protobuf/tensorflow_server.proto",
         "protobuf/trackable_object_graph.proto",
         "protobuf/transport_options.proto",
-        "util/test_log.proto",
     ],
     cc_api_version = 2,
     make_default_target_header_only = True,
     protodeps = [
         ":error_codes_proto_impl",
+        ":test_log_proto_impl",
         "//tensorflow/core/framework:protos_all",
         "//tensorflow/core/lib/core:error_codes_proto",
         "//tensorflow/core/profiler/protobuf:xplane_proto",
diff --git a/tensorflow/core/util/reporter.cc b/tensorflow/core/util/reporter.cc
index eb69e292116..8e9d863b4c2 100644
--- a/tensorflow/core/util/reporter.cc
+++ b/tensorflow/core/util/reporter.cc
@@ -15,9 +15,9 @@ limitations under the License.
 
 #include "tensorflow/core/util/reporter.h"
 
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/str_util.h"
 
 namespace tensorflow {
 

From bfe235d23fd60d251a0cfa4325bb1b92bbf47f49 Mon Sep 17 00:00:00 2001
From: Dan Moldovan 
Date: Mon, 2 Dec 2019 17:14:31 -0800
Subject: [PATCH 193/279] Cleanup: Use a standard name for the directives
 annotation.

PiperOrigin-RevId: 283450563
Change-Id: I7db4fde547fe9d0d66c7a2e161e481f8add0b5ff
---
 tensorflow/python/autograph/converters/directives.py     | 4 ++--
 .../python/autograph/converters/directives_test.py       | 3 +--
 tensorflow/python/autograph/core/converter.py            | 9 ---------
 tensorflow/python/autograph/pyct/anno.py                 | 2 ++
 tensorflow/python/autograph/pyct/templates.py            | 1 +
 5 files changed, 6 insertions(+), 13 deletions(-)

diff --git a/tensorflow/python/autograph/converters/directives.py b/tensorflow/python/autograph/converters/directives.py
index b712c21d364..fe1c75a5864 100644
--- a/tensorflow/python/autograph/converters/directives.py
+++ b/tensorflow/python/autograph/converters/directives.py
@@ -98,9 +98,9 @@ class DirectivesTransformer(converter.Base):
       raise ValueError(
           '"%s" must be used inside a statement' % directive.__name__)
     target = self.get_local(ENCLOSING_LOOP)
-    node_anno = anno.getanno(target, converter.AgAnno.DIRECTIVES, {})
+    node_anno = anno.getanno(target, anno.Basic.DIRECTIVES, {})
     node_anno[directive] = _map_args(call_node, directive)
-    anno.setanno(target, converter.AgAnno.DIRECTIVES, node_anno)
+    anno.setanno(target, anno.Basic.DIRECTIVES, node_anno)
     return call_node
 
   def visit_Name(self, node):
diff --git a/tensorflow/python/autograph/converters/directives_test.py b/tensorflow/python/autograph/converters/directives_test.py
index 62de7d6229a..545094521ec 100644
--- a/tensorflow/python/autograph/converters/directives_test.py
+++ b/tensorflow/python/autograph/converters/directives_test.py
@@ -20,7 +20,6 @@ from __future__ import print_function
 
 from tensorflow.python.autograph.converters import directives as directives_converter
 from tensorflow.python.autograph.core import converter_testing
-from tensorflow.python.autograph.core.converter import AgAnno
 from tensorflow.python.autograph.lang import directives
 from tensorflow.python.autograph.pyct import anno
 from tensorflow.python.autograph.pyct import parser
@@ -68,7 +67,7 @@ class DirectivesTest(converter_testing.TestCase):
     node, ctx = self.prepare(test_fn, {'directives': directives})
     node = directives_converter.transform(node, ctx)
 
-    d = anno.getanno(node.body[1], AgAnno.DIRECTIVES)
+    d = anno.getanno(node.body[1], anno.Basic.DIRECTIVES)
     d = d[directives.set_loop_options]
     self.assertEqual(d['parallel_iterations'].n, 10)
     self.assertEqual(d['back_prop'].id, 'a')
diff --git a/tensorflow/python/autograph/core/converter.py b/tensorflow/python/autograph/core/converter.py
index 3102377d638..e286e38d855 100644
--- a/tensorflow/python/autograph/core/converter.py
+++ b/tensorflow/python/autograph/core/converter.py
@@ -354,15 +354,6 @@ class AnnotatedDef(reaching_definitions.Definition):
     self.directives = {}
 
 
-class AgAnno(enum.Enum):
-  """Annotation labels specific to AutoGraph. See anno.py."""
-
-  DIRECTIVES = 'User directives associated with the annotated statement.'
-
-  def __repr__(self):
-    return self.name
-
-
 def standard_analysis(node, context, is_initial=False):
   """Performs a complete static analysis of the given code.
 
diff --git a/tensorflow/python/autograph/pyct/anno.py b/tensorflow/python/autograph/pyct/anno.py
index e1f4af46cd7..a8ae864cd88 100644
--- a/tensorflow/python/autograph/pyct/anno.py
+++ b/tensorflow/python/autograph/pyct/anno.py
@@ -55,6 +55,8 @@ class Basic(NoValue):
       ' `name_map` allows renaming symbols.')
   ORIGIN = ('Information about the source code that converted code originated'
             ' from. See origin_information.py.')
+  DIRECTIVES = ('User directives associated with a statement or a variable.'
+                ' Typically, they affect the immediately-enclosing statement.')
 
 
 class Static(NoValue):
diff --git a/tensorflow/python/autograph/pyct/templates.py b/tensorflow/python/autograph/pyct/templates.py
index 24d2a0760b9..165319ef02b 100644
--- a/tensorflow/python/autograph/pyct/templates.py
+++ b/tensorflow/python/autograph/pyct/templates.py
@@ -120,6 +120,7 @@ class ReplaceTransformer(gast.NodeTransformer):
     self.preserved_annos = {
         anno.Basic.ORIGIN,
         anno.Basic.SKIP_PROCESSING,
+        anno.Basic.DIRECTIVES,
         anno.Static.ORIG_DEFINITIONS,
         'extra_test',
         'function_context_name',

From 9bc6a80e2b4e6fef93a0e6019e9474312d487e96 Mon Sep 17 00:00:00 2001
From: Chao Mei 
Date: Mon, 2 Dec 2019 17:29:35 -0800
Subject: [PATCH 194/279] Use op_name as the tag when recording profiling
 events so that it's more specific than the original general "OpInvoke" tag.

As a result of this, simplify the existing interpretation of recorded profiling events in profiling summarizer.

PiperOrigin-RevId: 283452539
Change-Id: Ifd4a3d27c6975aec9d48a812a41165aa7bda6e81
---
 tensorflow/lite/core/api/profiler.h           |  3 ---
 tensorflow/lite/core/subgraph.cc              | 12 +++++++++-
 .../lite/profiling/profile_summarizer.cc      | 23 ++++++-------------
 .../lite/profiling/profile_summarizer_test.cc |  2 +-
 tensorflow/lite/profiling/profiler.h          |  1 -
 tensorflow/lite/profiling/profiler_test.cc    |  6 ++---
 6 files changed, 22 insertions(+), 25 deletions(-)

diff --git a/tensorflow/lite/core/api/profiler.h b/tensorflow/lite/core/api/profiler.h
index aea70cd73f8..7bc296510d4 100644
--- a/tensorflow/lite/core/api/profiler.h
+++ b/tensorflow/lite/core/api/profiler.h
@@ -93,9 +93,6 @@ class ScopedOperatorProfile : public ScopedProfile {
   tflite::ScopedOperatorProfile TFLITE_VARNAME_UNIQ(_profile_, __COUNTER__)( \
       (profiler), (tag), (node_index))
 
-#define TFLITE_SCOPED_OPERATOR_PROFILE(profiler, node_index) \
-  TFLITE_SCOPED_TAGGED_OPERATOR_PROFILE((profiler), "OpInvoke", (node_index))
-
 #define TFLITE_SCOPED_DELEGATE_OPERATOR_PROFILE(profiler, node_index)   \
   TFLITE_SCOPED_TAGGED_OPERATOR_PROFILE((profiler), "DelegateOpInvoke", \
                                         (node_index))
diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc
index 38a9d24d782..e453ff2ff7e 100644
--- a/tensorflow/lite/core/subgraph.cc
+++ b/tensorflow/lite/core/subgraph.cc
@@ -758,7 +758,17 @@ TfLiteStatus Subgraph::Invoke() {
     TfLiteNode& node = nodes_and_registration_[node_index].first;
     const TfLiteRegistration& registration =
         nodes_and_registration_[node_index].second;
-    TFLITE_SCOPED_OPERATOR_PROFILE(profiler_.get(), node_index);
+
+    const char* op_name = nullptr;
+    if (profiler_) {
+      if (registration.builtin_code == tflite::BuiltinOperator_CUSTOM) {
+        const char* const custom_name = registration.custom_name;
+        op_name = custom_name ? custom_name : "UnknownCustomOp";
+      } else {
+        op_name = tflite::EnumNamesBuiltinOperator()[registration.builtin_code];
+      }
+    }
+    TFLITE_SCOPED_TAGGED_OPERATOR_PROFILE(profiler_.get(), op_name, node_index);
 
     // TODO(ycling): This is an extra loop through inputs to check if the data
     // need to be copied from Delegate buffer to raw memory, which is often not
diff --git a/tensorflow/lite/profiling/profile_summarizer.cc b/tensorflow/lite/profiling/profile_summarizer.cc
index d69b0a697d7..b004bc2e361 100644
--- a/tensorflow/lite/profiling/profile_summarizer.cc
+++ b/tensorflow/lite/profiling/profile_summarizer.cc
@@ -27,7 +27,7 @@ namespace {
 struct OperatorDetails {
   uint32_t subgraph_index;
   uint32_t node_index;
-  std::string name;
+  std::string op_description;
   std::vector inputs;
   std::vector outputs;
 };
@@ -74,20 +74,11 @@ OperatorDetails GetOperatorDetails(const tflite::Interpreter& interpreter,
   auto node_reg = subgraph->node_and_registration(node_index);
   auto inputs = node_reg->first.inputs;
   auto outputs = node_reg->first.outputs;
-  int code = node_reg->second.builtin_code;
-  const char* op_name = nullptr;
-  if (code == tflite::BuiltinOperator_CUSTOM) {
-    const char* custom_name = node_reg->second.custom_name;
-    op_name = custom_name ? custom_name : "UnknownCustomOp";
-  } else {
-    op_name = tflite::EnumNamesBuiltinOperator()[code];
-  }
   const char* profiling_string =
       interpreter.OpProfilingString(node_reg->second, &node_reg->first);
   OperatorDetails details;
-  details.name = op_name;
   if (profiling_string) {
-    details.name += ":" + std::string(profiling_string);
+    details.op_description = std::string(profiling_string);
   }
   details.inputs = GetTensorNames(interpreter, inputs);
   details.outputs = GetTensorNames(interpreter, outputs);
@@ -132,9 +123,6 @@ void ProfileSummarizer::ProcessProfiles(
 
   int64_t base_start_us = events[0]->begin_timestamp_us;
   int node_num = 0;
-  auto tag_string = [](const string& s, const string& t) {
-    return (t == "OpInvoke" || t == "DelegateOpInvoke") ? s : s + "/" + t;
-  };
 
   // Total time will be accumulated per subgraph.
   std::map total_us_per_subgraph_map;
@@ -154,13 +142,16 @@ void ProfileSummarizer::ProcessProfiles(
 
       const auto op_details =
           GetOperatorDetails(interpreter, subgraph_index, node_index);
-      const auto type_in_stats = tag_string(op_details.name, event->tag);
+      std::string type_in_stats(event->tag);
+      if (!op_details.op_description.empty()) {
+        type_in_stats += "/" + op_details.op_description;
+      }
 
       const auto node_name = ToString(op_details.outputs);
       // Append node index to node name because 'stats_calculator' can not
       // distinguish two nodes w/ the same 'node_name'.
       const auto node_name_in_stats =
-          tag_string(node_name + ":" + std::to_string(node_index), event->tag);
+          node_name + ":" + std::to_string(node_index);
 
       stats_calculator->AddNodeStats(node_name_in_stats, type_in_stats,
                                      node_num, start_us, node_exec_time,
diff --git a/tensorflow/lite/profiling/profile_summarizer_test.cc b/tensorflow/lite/profiling/profile_summarizer_test.cc
index 6340921bc0e..0c4b9fcd88f 100644
--- a/tensorflow/lite/profiling/profile_summarizer_test.cc
+++ b/tensorflow/lite/profiling/profile_summarizer_test.cc
@@ -141,7 +141,7 @@ TEST(ProfileSummarizerTest, InterpreterPlusProfilingDetails) {
   summarizer.ProcessProfiles(profiler.GetProfileEvents(), *interpreter);
   auto output = summarizer.GetOutputString();
   // TODO(shashishekhar): Add a better test here.
-  ASSERT_TRUE(output.find("SimpleOpEval:Profile") != std::string::npos)
+  ASSERT_TRUE(output.find("SimpleOpEval/Profile") != std::string::npos)
       << output;
 }
 
diff --git a/tensorflow/lite/profiling/profiler.h b/tensorflow/lite/profiling/profiler.h
index e75c90bf6b6..ff398698616 100644
--- a/tensorflow/lite/profiling/profiler.h
+++ b/tensorflow/lite/profiling/profiler.h
@@ -32,6 +32,5 @@ using Profiler = NoopProfiler;
 }  // namespace tflite
 
 #define SCOPED_TAGGED_OPERATOR_PROFILE TFLITE_SCOPED_TAGGED_OPERATOR_PROFILE
-#define SCOPED_OPERATOR_PROFILE TFLITE_SCOPED_OPERATOR_PROFILE
 
 #endif  // TENSORFLOW_LITE_PROFILING_PROFILER_H_
diff --git a/tensorflow/lite/profiling/profiler_test.cc b/tensorflow/lite/profiling/profiler_test.cc
index 57da951c8ce..cedb109697d 100644
--- a/tensorflow/lite/profiling/profiler_test.cc
+++ b/tensorflow/lite/profiling/profiler_test.cc
@@ -97,13 +97,13 @@ TEST(ProfilingTest, ProfilesAreCollected) {
 
 TEST(ProfilingTest, NullProfiler) {
   Profiler* profiler = nullptr;
-  { SCOPED_OPERATOR_PROFILE(profiler, 1); }
+  { SCOPED_TAGGED_OPERATOR_PROFILE(profiler, "noop", 1); }
 }
 
 TEST(ProfilingTest, ScopedProfile) {
   BufferedProfiler profiler(1024);
   profiler.StartProfiling();
-  { SCOPED_OPERATOR_PROFILE(&profiler, 1); }
+  { SCOPED_TAGGED_OPERATOR_PROFILE(&profiler, "noop", 1); }
   profiler.StopProfiling();
   auto profile_events = profiler.GetProfileEvents();
   EXPECT_EQ(1, profile_events.size());
@@ -112,7 +112,7 @@ TEST(ProfilingTest, ScopedProfile) {
 TEST(ProfilingTest, NoopProfiler) {
   NoopProfiler profiler;
   profiler.StartProfiling();
-  { SCOPED_OPERATOR_PROFILE(&profiler, 1); }
+  { SCOPED_TAGGED_OPERATOR_PROFILE(&profiler, "noop", 1); }
   profiler.StopProfiling();
   auto profile_events = profiler.GetProfileEvents();
   EXPECT_EQ(0, profile_events.size());

From 41a576f5051e6e4a1afae4931ad2d38f6568f4aa Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Mon, 2 Dec 2019 17:34:52 -0800
Subject: [PATCH 195/279] Throw an explicit error if user call TPUStrategy
 experimental_run_v2 in eager mode with a python function.

PiperOrigin-RevId: 283453288
Change-Id: I381a61afbaf6cb74ccd1ad1f556d8e5cf3f962f2
---
 .../distribute/custom_training_loop_test.py   |  5 +-
 tensorflow/python/distribute/tpu_strategy.py  | 60 ++------------
 tensorflow/python/distribute/values_test.py   | 79 +++----------------
 3 files changed, 21 insertions(+), 123 deletions(-)

diff --git a/tensorflow/python/distribute/custom_training_loop_test.py b/tensorflow/python/distribute/custom_training_loop_test.py
index e9b283d376c..1db9bff21f0 100644
--- a/tensorflow/python/distribute/custom_training_loop_test.py
+++ b/tensorflow/python/distribute/custom_training_loop_test.py
@@ -36,8 +36,9 @@ class InputIterationTest(test.TestCase, parameterized.TestCase):
 
   @combinations.generate(
       combinations.combine(
-          distribution=strategy_combinations.strategies_minus_tpu,
-          mode=["eager"]))
+          distribution=strategy_combinations.all_strategies,
+          mode=["eager"]
+      ))
   def testFullEager(self, distribution):
     dataset = self._get_dataset()
 
diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index 8f32e8e2226..2dd4309537a 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -37,7 +37,6 @@ from tensorflow.python.distribute import values
 from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
-from tensorflow.python.eager import function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import device_spec
 from tensorflow.python.framework import dtypes
@@ -83,29 +82,6 @@ def maybe_init_scope():
       yield
 
 
-def validate_experimental_run_function(fn):
-  """Validate the function passed into strategy.experimental_run_v2."""
-
-  # We allow three types of functions/objects passed into TPUStrategy
-  # experimental_run_v2 in eager mode:
-  #   1. a user annotated tf.function
-  #   2. a ConcreteFunction, this is mostly what you get from loading a saved
-  #      model.
-  #   3. a callable object and the `__call__` method itself is a tf.function.
-  #
-  # Otherwise we return an error, because we don't support eagerly running
-  # experimental_run_v2 in TPUStrategy.
-
-  if context.executing_eagerly() and not isinstance(
-      fn, def_function.Function) and not isinstance(
-          fn, function.ConcreteFunction) and not (callable(fn) and isinstance(
-              fn.__call__, def_function.Function)):
-    raise NotImplementedError(
-        "TPUStrategy.experimental_run_v2(fn, ...) does not support eager "
-        "execution. Either convert `fn` into a tf.function or consider "
-        "calling strategy.experimental_run_v2 inside a tf.function.")
-
-
 @tf_export("distribute.experimental.TPUStrategy", v1=[])
 class TPUStrategy(distribute_lib.Strategy):
   """TPU distribution strategy implementation."""
@@ -113,36 +89,14 @@ class TPUStrategy(distribute_lib.Strategy):
   def __init__(self,
                tpu_cluster_resolver=None,
                device_assignment=None):
-    """Synchronous training in TPU donuts or Pods.
-
-    To construct a TPUStrategy object, you need to run the
-    initialization code as below:
-
-    ```python
-    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
-    tf.config.experimental_connect_to_cluster(resolver)
-    tf.tpu.experimental.initialize_tpu_system(resolver)
-    strategy = tf.distribute.experimental.TPUStrategy(resolver)
-    ```
-
-    While using distribution strategies, the variables created within strategy's
-    scope will be replicated across all the replicas and can be kept in sync
-    using all-reduce algorithms.
-
-    To run TF2 programs on TPUs, you can either use `.compile` and
-    `.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized
-    training loop by calling `strategy.experimental_run_v2` directly. Note that
-    TPUStrategy doesn't support pure eager execution, so please make sure the
-    function passed into `strategy.experimental_run_v2` is a `tf.function` or
-    `strategy.experimental_run_v2` us called inside a `tf.function` if running
-    in eager mode.
+    """Initializes the TPUStrategy object.
 
     Args:
       tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
-        which provides information about the TPU cluster.
+          which provides information about the TPU cluster.
       device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to
-        specify the placement of replicas on the TPU cluster. Currently only
-        supports the usecase of using a single core within a TPU cluster.
+          specify the placement of replicas on the TPU cluster. Currently only
+          supports the usecase of using a single core within a TPU cluster.
     """
     super(TPUStrategy, self).__init__(TPUExtended(
         self, tpu_cluster_resolver, device_assignment=device_assignment))
@@ -157,8 +111,6 @@ class TPUStrategy(distribute_lib.Strategy):
   # This implementation runs a single step. It does not use infeed or outfeed.
   def experimental_run_v2(self, fn, args=(), kwargs=None):
     """See base class."""
-    validate_experimental_run_function(fn)
-
     # Note: the target function is converted to graph even when in Eager mode,
     # so autograph is on by default here.
     fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx())
@@ -205,8 +157,6 @@ class TPUStrategyV1(distribute_lib.StrategyV1):
   # This implementation runs a single step. It does not use infeed or outfeed.
   def experimental_run_v2(self, fn, args=(), kwargs=None):
     """See base class."""
-    validate_experimental_run_function(fn)
-
     fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx())
     return self.extended.tpu_run(fn, args, kwargs)
 
@@ -749,7 +699,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
         ]
 
       # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
-      if result[0] is None or isinstance(result[0], ops.Operation):
+      if result[0] is None:
         replicate_outputs = [None] * len(replicate_outputs)
       else:
         replicate_outputs = [
diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py
index 26d0eb3ac32..d97d1155c82 100644
--- a/tensorflow/python/distribute/values_test.py
+++ b/tensorflow/python/distribute/values_test.py
@@ -818,31 +818,13 @@ class SyncOnReadVariablePropertiesTest(test.TestCase):
     self.assertEqual(2., self.evaluate(add1(replica_local)))
 
 
-def mirrored_and_tpu_strategy_combinations():
-  return combinations.combine(
-      distribution=[
-          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
-          strategy_combinations.tpu_strategy,
-      ],
-      mode=["graph", "eager"])
-
-
-def strategy_and_run_tf_function_combinations():
-  # Test the combination of different strategies and whether a tf.function
-  # is passed into strategy.experimental_run_v2."""
-  return combinations.combine(
-      distribution=[
-          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
-      ],
-      mode=["graph", "eager"],
-      experimental_run_tf_function=[True, False]) + combinations.combine(
-          distribution=[
-              strategy_combinations.tpu_strategy,
-          ],
-          mode=["graph", "eager"],
-          experimental_run_tf_function=[True])
-
-
+@combinations.generate(
+    combinations.combine(
+        distribution=[
+            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+            strategy_combinations.tpu_strategy,
+        ],
+        mode=["graph", "eager"]))
 class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
 
   def _assign_replica_local(self, v, new):
@@ -860,7 +842,6 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
     save_path, _ = self._save_return_saver(sess, var)
     return save_path
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveAndRestoreReplicaLocalSumOneGraph(self, distribution):
     with self.cached_session() as sess:
       v, replica_local = _make_replica_local(
@@ -881,7 +862,6 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
         saver.restore(sess, save_path)
         self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveAndRestoreReplicaLocalMeanOneGraph(self, distribution):
     if context.num_gpus() < 1 and context.executing_eagerly():
       self.skipTest("A GPU is not available for this test in eager mode.")
@@ -998,46 +978,36 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
         saver.restore(sess, save_path)
         self.assertEqual([1.75, 1.75], self.evaluate([v[0], v[1]]))
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveReplicaLocalRestoreReplicaLocalMean(self, distribution):
     save_path = self._save_replica_local_mean(distribution)
     self._restore_replica_local_mean(save_path, distribution)
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveReplicaLocalRestoreReplicaLocalSum(self, distribution):
     save_path = self._save_replica_local_sum(distribution)
     self._restore_replica_local_sum(save_path, distribution)
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveReplicaLocalMeanRestoreNormal(self, distribution):
     save_path = self._save_replica_local_mean(distribution)
     self._restore_normal(save_path)
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveReplicaLocalSumRestoreNormal(self, distribution):
     save_path = self._save_replica_local_sum(distribution)
     self._restore_normal(save_path)
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveNormalRestoreReplicaLocalMean(self, distribution):
     save_path = self._save_normal()
     self._restore_replica_local_mean(save_path, distribution)
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveNormalRestoreReplicaLocalSum(self, distribution):
     save_path = self._save_normal()
     self._restore_replica_local_sum(save_path, distribution)
 
-  @combinations.generate(strategy_and_run_tf_function_combinations())
-  def testAssign(self, distribution, experimental_run_tf_function):
-
+  def testAssign(self, distribution):
     def assign(fn, v, update_value, cross_replica):
       update_fn = lambda: getattr(v, fn)(update_value)
       if cross_replica:
         return update_fn()
       else:
-        if experimental_run_tf_function:
-          update_fn = def_function.function(update_fn)
         return distribution.experimental_local_results(
             distribution.experimental_run_v2(update_fn))
     updates = [("assign", 1.), ("assign_add", 1.), ("assign_sub", -1.)]
@@ -1063,17 +1033,12 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
         self.assertAllEqual(self.evaluate(component.read_value()),
                             self.evaluate(array_ops.ones_like(component)))
 
-  @combinations.generate(strategy_and_run_tf_function_combinations())
-  def testAssignDtypeConversion(self, distribution,
-                                experimental_run_tf_function):
-
+  def testAssignDtypeConversion(self, distribution):
     def assign(fn, v, update_value, cross_replica):
       update_fn = lambda: getattr(v, fn)(update_value)
       if cross_replica:
         return update_fn()
       else:
-        if experimental_run_tf_function:
-          update_fn = def_function.function(update_fn)
         return distribution.experimental_local_results(
             distribution.experimental_run_v2(update_fn))
     updates = [("assign", 1), ("assign_add", 1), ("assign_sub", -1)]
@@ -1099,7 +1064,6 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
         self.assertAllEqual(self.evaluate(component.read_value()),
                             self.evaluate(array_ops.ones_like(component)))
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testAssignWithAggregationSum(self, distribution):
     with distribution.scope():
       v = variable_scope.variable(
@@ -1112,7 +1076,6 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
       self.assertAllEqual(self.evaluate(component.read_value()),
                           self.evaluate(array_ops.ones_like(component)))
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testAssignAddSubWithAggregationSum(self, distribution):
     with distribution.scope():
       v = variable_scope.variable(
@@ -1127,9 +1090,7 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
         ValueError, "SyncOnReadVariable does not support "):
       self.evaluate(v.assign_sub(1.))
 
-  @combinations.generate(strategy_and_run_tf_function_combinations())
-  def testReadValueInReplicaContext(self, distribution,
-                                    experimental_run_tf_function):
+  def testReadValueInReplicaContext(self, distribution):
     aggregations = [
         variables_lib.VariableAggregation.NONE,
         variables_lib.VariableAggregation.SUM,
@@ -1143,19 +1104,12 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
             synchronization=variables_lib.VariableSynchronization.ON_READ,
             aggregation=aggregation)
       self.evaluate(variables_lib.global_variables_initializer())
-      if experimental_run_tf_function:
-        read_var_fn = def_function.function(v.read_value)
-      else:
-        read_var_fn = v.read_value
-      results = self.evaluate(
-          distribution.experimental_local_results(
-              distribution.experimental_run_v2(read_var_fn)))
+      results = self.evaluate(distribution.experimental_local_results(
+          distribution.experimental_run_v2(v.read_value)))
       for component, value in zip(v._values, results):
         self.assertAllEqual(self.evaluate(component.read_value()), value)
 
-  @combinations.generate(strategy_and_run_tf_function_combinations())
-  def testReadValueInCrossReplicaContext(self, distribution,
-                                         experimental_run_tf_function):
+  def testReadValueInCrossReplicaContext(self, distribution):
     aggregations = [
         variables_lib.VariableAggregation.SUM,
         variables_lib.VariableAggregation.MEAN,
@@ -1171,15 +1125,10 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
             synchronization=variables_lib.VariableSynchronization.ON_READ,
             aggregation=aggregation)
       self.evaluate(variables_lib.global_variables_initializer())
-
       def assign(v=v):
         ctx = distribution_strategy_context.get_replica_context()
         replica_id = ctx.replica_id_in_sync_group
         return v.assign(math_ops.cast(replica_id, dtypes.float32))
-
-      if experimental_run_tf_function:
-        assign = def_function.function(assign)
-
       self.evaluate(distribution.experimental_local_results(
           distribution.experimental_run_v2(assign)))
       result = self.evaluate(v.read_value())
@@ -1193,7 +1142,6 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
         expected = 0
       self.assertEqual(expected, result, aggregation)
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testReadValueWithAggregationNoneInCrossReplicaContext(self, distribution):
     with distribution.scope():
       v = variable_scope.variable(
@@ -1205,7 +1153,6 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
         ValueError, "Could not convert from .* VariableAggregation\\.NONE"):
       self.evaluate(v.read_value())
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testInitializedToSameValueInsideEagerRun(self, distribution):
     if not context.executing_eagerly(): self.skipTest("eager only")
 

From 489a7c85b889d6c4b1704ba8cd3876747cffb8a7 Mon Sep 17 00:00:00 2001
From: Gunhan Gulsoy 
Date: Mon, 2 Dec 2019 17:39:42 -0800
Subject: [PATCH 196/279] Remove dependence on core/lib/core/blocking_counter.

the library has moved to core/platform

PiperOrigin-RevId: 283453966
Change-Id: I7e0248a318b54bd71d22a6c40eac5613a545bcbf
---
 tensorflow/core/BUILD                                       | 1 +
 tensorflow/core/platform/cloud/BUILD                        | 1 +
 tensorflow/core/platform/cloud/ram_file_block_cache_test.cc | 4 +++-
 tensorflow/core/platform/default/build_refactor.bzl         | 2 +-
 tensorflow/core/platform/threadpool.cc                      | 2 +-
 tensorflow/core/platform/unbounded_work_queue_test.cc       | 2 +-
 6 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index e60e0f7b1a0..f07251955dc 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -2111,6 +2111,7 @@ LIB_INTERNAL_PUBLIC_HEADERS = [
     "//tensorflow/core/lib/random:legacy_lib_internal_public_random_headers",
     "//tensorflow/core/lib/strings:legacy_lib_internal_public_string_headers",
     "lib/wav/wav_io.h",
+    "//tensorflow/core/platform:blocking_counter.h",
     "//tensorflow/core/platform:demangle.h",
     "//tensorflow/core/platform:denormal.h",
     "//tensorflow/core/platform:host_info.h",
diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD
index e38c51974fb..7b194e78911 100644
--- a/tensorflow/core/platform/cloud/BUILD
+++ b/tensorflow/core/platform/cloud/BUILD
@@ -351,6 +351,7 @@ tf_cc_test(
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "//tensorflow/core/platform:blocking_counter",
     ],
 )
 
diff --git a/tensorflow/core/platform/cloud/ram_file_block_cache_test.cc b/tensorflow/core/platform/cloud/ram_file_block_cache_test.cc
index 9f37be65943..e018333b1b7 100644
--- a/tensorflow/core/platform/cloud/ram_file_block_cache_test.cc
+++ b/tensorflow/core/platform/cloud/ram_file_block_cache_test.cc
@@ -14,9 +14,11 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/platform/cloud/ram_file_block_cache.h"
+
 #include 
-#include "tensorflow/core/lib/core/blocking_counter.h"
+
 #include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/blocking_counter.h"
 #include "tensorflow/core/platform/cloud/now_seconds_env.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/notification.h"
diff --git a/tensorflow/core/platform/default/build_refactor.bzl b/tensorflow/core/platform/default/build_refactor.bzl
index 6d1beca6923..a29cce63fd7 100644
--- a/tensorflow/core/platform/default/build_refactor.bzl
+++ b/tensorflow/core/platform/default/build_refactor.bzl
@@ -75,10 +75,10 @@ TF_DEFAULT_PLATFORM_LIBRARIES = {
             "@com_google_absl//absl/time",
             "@com_google_absl//absl/types:optional",
             "//third_party/eigen3",
-            "//tensorflow/core/lib/core:blocking_counter",
             "//tensorflow/core/lib/core:error_codes_proto_cc",
             "//tensorflow/core/lib/core:stringpiece",
             "//tensorflow/core/platform",
+            "//tensorflow/core/platform:blocking_counter",
             "//tensorflow/core/platform:context",
             "//tensorflow/core/platform:cord",
             "//tensorflow/core/platform:denormal",
diff --git a/tensorflow/core/platform/threadpool.cc b/tensorflow/core/platform/threadpool.cc
index fa22ad3867b..18aa7684aba 100644
--- a/tensorflow/core/platform/threadpool.cc
+++ b/tensorflow/core/platform/threadpool.cc
@@ -19,7 +19,7 @@ limitations under the License.
 
 #include "absl/types/optional.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/lib/core/blocking_counter.h"
+#include "tensorflow/core/platform/blocking_counter.h"
 #include "tensorflow/core/platform/context.h"
 #include "tensorflow/core/platform/denormal.h"
 #include "tensorflow/core/platform/logging.h"
diff --git a/tensorflow/core/platform/unbounded_work_queue_test.cc b/tensorflow/core/platform/unbounded_work_queue_test.cc
index 03d91cd4893..ada99c5e1a3 100644
--- a/tensorflow/core/platform/unbounded_work_queue_test.cc
+++ b/tensorflow/core/platform/unbounded_work_queue_test.cc
@@ -16,8 +16,8 @@ limitations under the License.
 #include "tensorflow/core/platform/unbounded_work_queue.h"
 
 #include "absl/memory/memory.h"
-#include "tensorflow/core/lib/core/blocking_counter.h"
 #include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/blocking_counter.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {

From e36d1de5d21e1242ec247eefed69d3d286709990 Mon Sep 17 00:00:00 2001
From: Smit Hinsu 
Date: Mon, 2 Dec 2019 17:40:57 -0800
Subject: [PATCH 197/279] Handle 32 bit integer paddings operand for tf.PadV2
 lowering

Updated lowering helpers to always create xla_hlo.pad attributes of 64 bit integer type.

PiperOrigin-RevId: 283454108
Change-Id: I1647754a32ebe7ea7b971e6f9fd560eb87f43d3d
---
 .../compiler/mlir/xla/tests/legalize-tf.mlir   | 11 +++++++++++
 .../mlir/xla/transforms/legalize_tf.cc         | 17 +++++++++++++++--
 .../xla/transforms/legalize_tf_patterns.td     | 18 +++++++-----------
 3 files changed, 33 insertions(+), 13 deletions(-)

diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index d8093f1a39a..94a445fe8bd 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -624,6 +624,17 @@ func @padv2_2D(%arg0: tensor<3x2xf32>, %arg1: tensor) -> tensor<6x9xf32> {
   return %1 : tensor<6x9xf32>
 }
 
+// CHECK-LABEL: func @padv2_i32_paddings
+func @padv2_i32_paddings(%arg0: tensor<3x2xf32>, %arg1: tensor) -> tensor<6x9xf32> {
+  %padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi32> } : () -> tensor<2x2xi32>
+  // CHECK: "xla_hlo.pad"(%arg0, %arg1) {
+  // CHECK-SAME:    edge_padding_high = dense<[2, 4]> : tensor<2xi64>,
+  // CHECK-SAME:    edge_padding_low = dense<[1, 3]> : tensor<2xi64>,
+  // CHECK-SAME:    interior_padding = dense<0> : tensor<2xi64>
+  %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3x2xf32>, tensor<2x2xi32>, tensor) -> tensor<6x9xf32>
+  return %1 : tensor<6x9xf32>
+}
+
 //===----------------------------------------------------------------------===//
 // Identity op legalizations.
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 94e0ce35cb0..d7f3bf243e5 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -232,11 +232,14 @@ static DenseIntElementsAttr Get2DTransposePerm(BoolAttr transpose, Builder *b) {
 // Pad op utilities.
 //===----------------------------------------------------------------------===//
 
+// Slices input attribute of rank two and returns the specified column.
+//
+// Always returns 64 bit integer attribute regardless of bitwidth of the input
+// attribute.
 static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D(
-    Builder *b, ElementsAttr input, int column) {
+    ElementsAttr input, int column) {
   auto int_attr = input.cast();
   auto shaped_type = int_attr.getType();
-  auto element_type = shaped_type.getElementType();
   auto shape = shaped_type.getShape();
 
   if (shape.size() != 2) return DenseIntElementsAttr();
@@ -250,10 +253,20 @@ static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D(
     }
   }
 
+  auto element_type = IntegerType::get(64, input.getContext());
   return DenseIntElementsAttr::get(
       RankedTensorType::get({shape[0]}, element_type), values);
 }
 
+// Returns interior padding to use in HLO Pad op based on the TensorFlow padding
+// in TensorFlow PadV2 op.
+static DenseIntElementsAttr GetInteriorPadding(ElementsAttr tf_padding) {
+  auto length = tf_padding.getType().getShape()[0];
+  auto element_type = IntegerType::get(64, tf_padding.getContext());
+  return DenseIntElementsAttr::get(
+      RankedTensorType::get({length}, element_type), 0);
+}
+
 //===----------------------------------------------------------------------===//
 // Binary op utilities.
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
index 24d24e864d9..fb8c6736309 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
@@ -256,25 +256,21 @@ def : Pat<(TF_RFFTOp $input, (TF_ConstOp I32ElementsAttr:$fft_length)),
 // Pad op patterns.
 //===----------------------------------------------------------------------===//
 
-def ZeroPaddingAttr : NativeCodeCall <
-  "DenseIntElementsAttr::get("
-    "RankedTensorType::get($0.getType().getShape()[0],"
-    "                      getElementTypeOrSelf($0.getType())), "
-    "{$_builder.getZeroAttr(getElementTypeOrSelf($0.getType()))})">;
-
 class SliceDenseIntElementsAttrColumn2D : NativeCodeCall<
-  "SliceDenseIntElementsAttrColumn2D("
-    "&$_builder, $0, " # column # " )">;
+  "SliceDenseIntElementsAttrColumn2D($0, " # column # " )">;
 
 class SliceDenseIntElementsAttr : NativeCodeCall<
-  "SliceDenseIntElementsAttr(&$_builder, $0, " # index # ", " # axis # ")">;
+  "SliceDenseIntElementsAttr($0, " # index # ", " # axis # ")">;
 
+// Interior padding attribute based on the TF padding.
+def GetInteriorPadding : NativeCodeCall <
+  "GetInteriorPadding($0)">;
 
-def : Pat<(TF_PadV2Op $input, (TF_ConstOp I64ElementsAttr:$padding), $c),
+def : Pat<(TF_PadV2Op $input, (TF_ConstOp $padding), $c),
           (HLO_PadOp $input, $c,
            (SliceDenseIntElementsAttrColumn2D<"0"> $padding),
            (SliceDenseIntElementsAttrColumn2D<"1"> $padding),
-           (ZeroPaddingAttr $padding))>;
+           (GetInteriorPadding $padding))>;
 
 //===----------------------------------------------------------------------===//
 // Identity op patterns.

From e54cde4b433d2b7fbeb3b4614aa874ea0388a509 Mon Sep 17 00:00:00 2001
From: Juhyun Lee 
Date: Mon, 2 Dec 2019 18:22:57 -0800
Subject: [PATCH 198/279] Make shape_inference_helper buildable for Android.

PiperOrigin-RevId: 283459430
Change-Id: I67ebedb6559125797e3856db557df579824a007c
---
 tensorflow/compiler/jit/BUILD | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 89e6a04abd7..bcdfa019459 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -231,7 +231,14 @@ cc_library(
     srcs = ["shape_inference_helpers.cc"],
     hdrs = ["shape_inference_helpers.h"],
     visibility = [":friends"],
-    deps = ["//tensorflow/core:graph"],
+    deps = select({
+        "//tensorflow:android": [
+            "//tensorflow/core:android_tensorflow_lib",
+        ],
+        "//conditions:default": [
+            "//tensorflow/core:graph",
+        ],
+    }),
 )
 
 # Internal targets below this point.

From 8eaf31f8f68bc5eb940cd490e2295dc5f2b597d7 Mon Sep 17 00:00:00 2001
From: Yunlu Li 
Date: Mon, 2 Dec 2019 19:10:40 -0800
Subject: [PATCH 199/279] Fix tensor initialization.

PiperOrigin-RevId: 283464230
Change-Id: Ia5cf3eb3442af7f1f25a6c678a563ae2fd89bd8b
---
 tensorflow/lite/experimental/kernels/hashtable_ops_test.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/lite/experimental/kernels/hashtable_ops_test.cc b/tensorflow/lite/experimental/kernels/hashtable_ops_test.cc
index ddc31a5abfd..8790a2c9960 100644
--- a/tensorflow/lite/experimental/kernels/hashtable_ops_test.cc
+++ b/tensorflow/lite/experimental/kernels/hashtable_ops_test.cc
@@ -695,7 +695,7 @@ TEST(HashtableOpsTest, TestHashtable) {
 
 template 
 TfLiteTensor CreateTensor(TfLiteType type, std::vector vec) {
-  TfLiteTensor tensor;
+  TfLiteTensor tensor = {};
   TfLiteIntArray* dims = TfLiteIntArrayCreate(1);
   dims->data[0] = vec.size();
   tensor.dims = dims;

From 1c03f956c5851b60fe4fcf16fb4e50bed0fb653b Mon Sep 17 00:00:00 2001
From: Peter Hawkins 
Date: Mon, 2 Dec 2019 19:42:56 -0800
Subject: [PATCH 200/279] [XLA] Refactor Executable::ExecuteAsyncOnStream.

Change implementations of Executable to always implement the overload that takes a std::vector>. Make the non-owning version a wrapper around the maybe-owning version.

Simplification in preparation for plumbing buffer donation into JAX. This change is also a necessary preparatory step for implementing buffer donation on CPU and GPU.

PiperOrigin-RevId: 283467139
Change-Id: I9f59ce6ef4405e3849a2f2ad1ab5a38419125c90
---
 tensorflow/compiler/xla/service/cpu/BUILD     |  7 ++++
 .../xla/service/cpu/cpu_executable.cc         | 40 ++++++++++++------
 .../compiler/xla/service/cpu/cpu_executable.h | 14 ++++---
 tensorflow/compiler/xla/service/executable.cc | 38 +++++++++++++----
 tensorflow/compiler/xla/service/executable.h  | 10 ++---
 .../xla/service/gpu/gpu_executable.cc         | 39 ++++++++++--------
 .../compiler/xla/service/gpu/gpu_executable.h |  9 +---
 .../service/hlo_input_output_alias_config.cc  |  3 +-
 .../compiler/xla/service/interpreter/BUILD    |  5 +++
 .../xla/service/interpreter/executable.cc     | 41 +++++++++++++++----
 .../xla/service/interpreter/executable.h      |  4 +-
 .../xla/service/maybe_owning_device_memory.cc |  3 +-
 .../xla/service/maybe_owning_device_memory.h  |  2 +-
 13 files changed, 146 insertions(+), 69 deletions(-)

diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 411ae8f7d64..e3aa1551b8a 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -242,9 +242,16 @@ cc_library(
         "//tensorflow/compiler/xla/service:hlo_dataflow_analysis",
         "//tensorflow/compiler/xla/service:hlo_execution_profile",
         "//tensorflow/compiler/xla/service:logical_buffer",
+        "//tensorflow/compiler/xla/service:maybe_owning_device_memory",
         "//tensorflow/compiler/xla/service:shaped_buffer",
         "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor_no_cuda",
+        "//tensorflow/core/platform:env",
+        "//tensorflow/core/platform:logging",
+        "//tensorflow/core/platform:macros",
+        "//tensorflow/core/platform:mutex",
+        "//tensorflow/core/platform:platform_port",
+        "//tensorflow/core/platform:types",
         "//tensorflow/core/profiler/lib:traceme",
         "//tensorflow/stream_executor:device_memory_allocator",
         "//tensorflow/stream_executor/host:host_stream",
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index 9b79e8ca8d7..083c3d31d74 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -32,6 +32,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
 #include "tensorflow/compiler/xla/shape_tree.h"
 #include "tensorflow/compiler/xla/shape_util.h"
@@ -44,6 +45,7 @@ limitations under the License.
 #include "tensorflow/core/platform/mem.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/device_memory_allocator.h"
 #include "tensorflow/stream_executor/host/host_stream.h"
 
 namespace xla {
@@ -73,11 +75,12 @@ CpuExecutable::CpuExecutable(
           << reinterpret_cast(compute_function_);
 }
 
-StatusOr,
-                   std::vector>>
+StatusOr,
+                    std::vector,
+                    std::vector>>
 CpuExecutable::CreateBufferTable(
     se::DeviceMemoryAllocator* memory_allocator, int device_ordinal,
-    absl::Span arguments) {
+    std::vector> arguments) {
   std::vector unowning_buffers(
       assignment_->Allocations().size());
   std::vector owning_buffers(
@@ -91,8 +94,9 @@ CpuExecutable::CreateBufferTable(
     VLOG(3) << allocation.ToString();
 
     if (allocation.is_entry_computation_parameter()) {
-      unowning_buffers[i] = arguments[allocation.parameter_number()]->buffer(
-          allocation.param_shape_index());
+      unowning_buffers[i] = arguments[allocation.parameter_number()]
+                                .element(allocation.param_shape_index())
+                                .AsDeviceMemoryBase();
       CHECK_EQ(allocation.size(), unowning_buffers[i].size())
           << "Size mismatch on param " << allocation.parameter_number()
           << " at shape index " << allocation.param_shape_index().ToString();
@@ -134,7 +138,17 @@ CpuExecutable::CreateBufferTable(
                       assignment_->GetUniqueTopLevelOutputSlice());
   VLOG(3) << "result index: " << result_slice.index();
 
-  return {{std::move(unowning_buffers), std::move(owning_buffers)}};
+  std::vector buffers_to_free;
+  for (ShapeTree& argument : arguments) {
+    for (std::pair& buffer : argument) {
+      auto maybe_owning_buffer = buffer.second.Release();
+      if (maybe_owning_buffer) {
+        buffers_to_free.push_back(std::move(*maybe_owning_buffer));
+      }
+    }
+  }
+  return {{std::move(unowning_buffers), std::move(owning_buffers),
+           std::move(buffers_to_free)}};
 }
 
 Status CpuExecutable::ExecuteComputeFunction(
@@ -268,9 +282,9 @@ StatusOr CpuExecutable::CreateResultShapedBuffer(
   return std::move(result_buffer);
 }
 
-StatusOr CpuExecutable::ExecuteAsyncOnStream(
+StatusOr CpuExecutable::ExecuteAsyncOnStream(
     const ServiceExecutableRunOptions* run_options,
-    absl::Span arguments,
+    std::vector> arguments,
     HloExecutionProfile* hlo_execution_profile) {
   if (GetRootValueSet().IsAmbiguous()) {
     return Unimplemented("Points-to set of root instruction is ambiguous");
@@ -283,7 +297,7 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream(
     for (int64 i = 0; i < entry_comp->num_parameters(); ++i) {
       const Shape& expected_shape =
           entry_comp->parameter_instruction(i)->shape();
-      const Shape& actual_shape = arguments[i]->on_device_shape();
+      const Shape& actual_shape = arguments[i].shape();
       CHECK(expected_shape == actual_shape) << absl::StreamFormat(
           "Shape mismatch on argument %d.  Expected %s, but was %s.", i,
           expected_shape.ToString(/*print_layout=*/true),
@@ -297,10 +311,11 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream(
   se::DeviceMemoryAllocator* memory_allocator = run_options->allocator();
   std::vector owning_buffers;
   std::vector unowning_buffers;
+  std::vector buffers_to_release;
   TF_ASSIGN_OR_RETURN(
-      std::tie(unowning_buffers, owning_buffers),
+      std::tie(unowning_buffers, owning_buffers, buffers_to_release),
       CreateBufferTable(memory_allocator, stream->parent()->device_ordinal(),
-                        arguments));
+                        std::move(arguments)));
 
   TF_ASSIGN_OR_RETURN(
       ScopedShapedBuffer result,
@@ -339,7 +354,8 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream(
                        std::move(owning_buffers)),
                    hlo_execution_profile});
 
-  return std::move(result);
+  return ExecutionOutput(std::move(result), std::move(buffers_to_release), {},
+                         se::OwningDeviceMemory());
 }
 
 /*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) {
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
index 37af630a2d9..6f8a7c3315a 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
@@ -55,9 +55,9 @@ class CpuExecutable : public Executable {
                 std::unique_ptr hlo_profile_index_map);
   ~CpuExecutable() override {}
 
-  StatusOr ExecuteAsyncOnStream(
+  StatusOr ExecuteAsyncOnStream(
       const ServiceExecutableRunOptions* run_options,
-      absl::Span arguments,
+      std::vector> arguments,
       HloExecutionProfile* hlo_execution_profile) override;
 
   // This should be called after set_ir_module_string.
@@ -96,11 +96,15 @@ class CpuExecutable : public Executable {
   //    allocated by this routine.  This routine allocates buffers for temporary
   //    storage and the live-out buffer into which the computation writes it
   //    result.
-  StatusOr,
-                     std::vector>>
+  //
+  //  - buffers_to_free: buffers whose ownership was donated by the caller that
+  //    are to be freed by the caller.
+  StatusOr,
+                      std::vector,
+                      std::vector>>
   CreateBufferTable(se::DeviceMemoryAllocator* memory_allocator,
                     int device_ordinal,
-                    absl::Span arguments);
+                    std::vector> arguments);
 
   // Calls the generated function performing the computation with the given
   // arguments using the supplied buffers.
diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc
index c21721c9339..9ece6172d12 100644
--- a/tensorflow/compiler/xla/service/executable.cc
+++ b/tensorflow/compiler/xla/service/executable.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/debug_options_flags.h"
 #include "tensorflow/compiler/xla/service/dump.h"
 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
+#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
 #include "tensorflow/compiler/xla/status.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -43,9 +44,36 @@ StatusOr Executable::ExecuteOnStream(
   return result;
 }
 
+static ShapeTree MakeMaybeOwningDeviceMemoryTree(
+    const ShapedBuffer& shaped_buffer) {
+  ShapeTree result(shaped_buffer.on_device_shape());
+  auto in_it = shaped_buffer.buffers().begin();
+  auto out_it = result.begin();
+  for (; in_it != shaped_buffer.buffers().end(); ++in_it, ++out_it) {
+    DCHECK(out_it != result.end());
+    out_it->second = MaybeOwningDeviceMemory(in_it->second);
+  }
+  return result;
+}
+
+StatusOr Executable::ExecuteAsyncOnStream(
+    const ServiceExecutableRunOptions* run_options,
+    absl::Span arguments,
+    HloExecutionProfile* hlo_execution_profile) {
+  std::vector> args(arguments.size());
+  auto out_it = args.begin();
+  for (const ShapedBuffer* arg : arguments) {
+    *out_it++ = MakeMaybeOwningDeviceMemoryTree(*arg);
+  }
+  TF_ASSIGN_OR_RETURN(ExecutionOutput out,
+                      ExecuteAsyncOnStream(run_options, std::move(args),
+                                           hlo_execution_profile));
+  return out.ConsumeResult();
+}
+
 StatusOr Executable::ExecuteOnStream(
     const ServiceExecutableRunOptions* run_options,
-    std::vector> arguments,
+    std::vector> arguments,
     HloExecutionProfile* hlo_execution_profile) {
   StatusOr result = ExecuteAsyncOnStream(
       run_options, std::move(arguments), hlo_execution_profile);
@@ -55,14 +83,6 @@ StatusOr Executable::ExecuteOnStream(
   return result;
 }
 
-StatusOr Executable::ExecuteAsyncOnStream(
-    const ServiceExecutableRunOptions* /*run_options*/,
-    std::vector> /*arguments*/,
-    HloExecutionProfile* /*hlo_execution_profile*/) {
-  return Unimplemented(
-      "MaybeOwningDeviceMemory version of overload is not implemented ");
-}
-
 StatusOr> Executable::ExecuteOnStreams(
     absl::Span run_options,
     absl::Span> arguments) {
diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h
index 971dab95bfd..496599e7aaf 100644
--- a/tensorflow/compiler/xla/service/executable.h
+++ b/tensorflow/compiler/xla/service/executable.h
@@ -160,22 +160,22 @@ class Executable {
   // If the hlo_execution_profile is provided as non-nullptr, profiling will be
   // enabled. Note that profiling is tricky to use correctly, as the profiling
   // objects (when they exist) must out-live the task.
-  virtual StatusOr ExecuteAsyncOnStream(
+  StatusOr ExecuteAsyncOnStream(
       const ServiceExecutableRunOptions* run_options,
       absl::Span arguments,
-      HloExecutionProfile* hlo_execution_profile) = 0;
+      HloExecutionProfile* hlo_execution_profile);
 
   // Same as ExecuteAsyncOnStream(), but blocks waiting for the computation to
   // complete.
   StatusOr ExecuteOnStream(
       const ServiceExecutableRunOptions* run_options,
-      std::vector> arguments,
+      std::vector> arguments,
       HloExecutionProfile* hlo_execution_profile);
 
   virtual StatusOr ExecuteAsyncOnStream(
       const ServiceExecutableRunOptions* run_options,
-      std::vector> arguments,
-      HloExecutionProfile* hlo_execution_profile);
+      std::vector> arguments,
+      HloExecutionProfile* hlo_execution_profile) = 0;
 
   // Same as ExecuteOnStream(), but runs this executable on multiple
   // streams. arguments[i] contains the arguments to the execution on
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index 99bc0f7fee0..93af1cd995e 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -299,11 +299,14 @@ GpuExecutable::ResolveConstantGlobals(se::Stream* stream) {
   return &module_globals_.emplace(executor, std::move(globals)).first->second;
 }
 
-StatusOr GpuExecutable::Execute(
+StatusOr GpuExecutable::ExecuteAsyncOnStream(
     const ServiceExecutableRunOptions* run_options,
-    absl::Span arguments,
-    HloExecutionProfile* hlo_execution_profile, bool block_host_until_done) {
-  se::DeviceMemoryAllocator* memory_allocator = run_options->allocator();
+    std::vector> arguments,
+    HloExecutionProfile* hlo_execution_profile) {
+  se::DeviceMemoryAllocator* const memory_allocator = run_options->allocator();
+  // Force synchronous execution if the allocator requires it.
+  const bool block_host_until_done =
+      !memory_allocator->AllowsAsynchronousDeallocation();
 
   if (GetRootValueSet().IsAmbiguous()) {
     return Unimplemented("Points-to set of root instruction is ambiguous");
@@ -334,7 +337,9 @@ StatusOr GpuExecutable::Execute(
       if (allocation.is_entry_computation_parameter()) {
         auto param_no = allocation.parameter_number();
         se::DeviceMemoryBase buffer =
-            arguments[param_no]->buffer(allocation.param_shape_index());
+            arguments[param_no]
+                .element(allocation.param_shape_index())
+                .AsDeviceMemoryBase();
 
         // All top-level buffers and sub-buffers must have an explicit, non-null
         // pointer, except for zero-sized buffers, which may be null.
@@ -423,19 +428,17 @@ StatusOr GpuExecutable::Execute(
       }));
   TF_RETURN_IF_ERROR(buffer_allocations->TearDown(buffers_in_result));
 
-  return std::move(shaped_buffer);
-}
-
-StatusOr GpuExecutable::ExecuteAsyncOnStream(
-    const ServiceExecutableRunOptions* run_options,
-    absl::Span arguments,
-    HloExecutionProfile* hlo_execution_profile) {
-  se::DeviceMemoryAllocator* memory_allocator = run_options->allocator();
-  // Force synchronous execution if the allocator requires it.
-  bool block_host_until_done =
-      !memory_allocator->AllowsAsynchronousDeallocation();
-  return Execute(run_options, arguments, hlo_execution_profile,
-                 block_host_until_done);
+  std::vector buffers_to_free;
+  for (ShapeTree& argument : arguments) {
+    for (std::pair& buffer : argument) {
+      auto maybe_owning_buffer = buffer.second.Release();
+      if (maybe_owning_buffer) {
+        buffers_to_free.push_back(std::move(*maybe_owning_buffer));
+      }
+    }
+  }
+  return ExecutionOutput(std::move(shaped_buffer), std::move(buffers_to_free),
+                         {}, {});
 }
 
 const InstructionValueSet& GpuExecutable::GetRootValueSet() const {
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
index 66f86d768be..51e86a9f8ee 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
@@ -82,9 +82,9 @@ class GpuExecutable : public Executable {
 
   // ExecuteAsyncOnStream will fail if the compute capability of the stream
   // doesn't match the compute capability passed to this object's constructor.
-  StatusOr ExecuteAsyncOnStream(
+  StatusOr ExecuteAsyncOnStream(
       const ServiceExecutableRunOptions* run_options,
-      absl::Span arguments,
+      std::vector> arguments,
       HloExecutionProfile* hlo_execution_profile) override;
 
   std::shared_ptr GetBufferAssignment() const {
@@ -92,11 +92,6 @@ class GpuExecutable : public Executable {
   }
 
  private:
-  StatusOr Execute(
-      const ServiceExecutableRunOptions* run_options,
-      absl::Span arguments,
-      HloExecutionProfile* hlo_execution_profile, bool block_host_until_done);
-
   // If `block_host_until_done` is false, execution will not block the host
   // until the kernels have completed. This is used as an optimization for
   // clients, such as Tensorflow, that use a single stream of execution for
diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc
index 1c5b166a801..3e82e3271bb 100644
--- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc
+++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc
@@ -151,7 +151,8 @@ absl::optional HloInputOutputAliasConfig::GetAliasedOutput(
 absl::optional
 HloInputOutputAliasConfig::GetAliasedParameter(
     const ShapeIndex& output_index) const {
-  CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index));
+  CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index))
+      << ToString() << " " << alias_.shape().ToString() << " " << output_index;
   return alias_.element(output_index);
 }
 
diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD
index 3073c68c975..84c7982ad10 100644
--- a/tensorflow/compiler/xla/service/interpreter/BUILD
+++ b/tensorflow/compiler/xla/service/interpreter/BUILD
@@ -89,10 +89,15 @@ cc_library(
         "//tensorflow/compiler/xla/service:hlo_evaluator",
         "//tensorflow/compiler/xla/service:hlo_execution_profile",
         "//tensorflow/compiler/xla/service:hlo_module_config",
+        "//tensorflow/compiler/xla/service:maybe_owning_device_memory",
         "//tensorflow/compiler/xla/service:shaped_buffer",
         "//tensorflow/compiler/xla/service:transfer_manager",
         "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor_no_cuda",
+        "//tensorflow/core/platform:env",
+        "//tensorflow/core/platform:macros",
+        "//tensorflow/core/platform:mutex",
+        "//tensorflow/core/platform:types",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/types:span",
     ],
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc
index 0dab86d986c..f82a439fdb0 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.cc
+++ b/tensorflow/compiler/xla/service/interpreter/executable.cc
@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/interpreter/executor.h"
+#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
 #include "tensorflow/compiler/xla/service/transfer_manager.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -39,24 +40,39 @@ namespace interpreter {
 InterpreterExecutable::InterpreterExecutable(
     std::unique_ptr hlo_module,
     std::unique_ptr evaluator)
-    : Executable(std::move(hlo_module), /*hlo_profile_printer=*/nullptr,
+    : Executable(std::move(hlo_module), /*hlo_profile_printer_data=*/nullptr,
                  /*hlo_profile_index_map=*/nullptr),
       evaluator_(std::move(evaluator)) {}
 
 InterpreterExecutable::~InterpreterExecutable() {}
 
-StatusOr InterpreterExecutable::ExecuteAsyncOnStream(
+StatusOr InterpreterExecutable::ExecuteAsyncOnStream(
     const ServiceExecutableRunOptions* run_options,
-    absl::Span arguments,
+    std::vector> arguments,
     HloExecutionProfile* hlo_execution_profile) {
   se::Stream* stream = run_options->stream();
   se::StreamExecutor* executor = stream->parent();
   const se::Platform* platform = executor->platform();
 
+  // Convert the ShapeTree to a ShapedBuffer. We do this so we can call
+  // TransferManager methods below.
+  std::vector argument_buffers;
+  argument_buffers.reserve(arguments.size());
+  for (const ShapeTree& arg : arguments) {
+    argument_buffers.push_back(ShapedBuffer(arg.shape(), arg.shape(),
+                                            /*platform=*/nullptr,
+                                            /*device_ordinal=*/0));
+    auto in_it = arg.begin();
+    auto out_it = argument_buffers.back().buffers().begin();
+    for (; in_it != arg.end(); ++in_it, ++out_it) {
+      out_it->second = in_it->second.AsDeviceMemoryBase();
+    }
+  }
+
   VLOG(1) << "Execute " << module().name();
   if (VLOG_IS_ON(2)) {
-    for (const auto& a : arguments) {
-      VLOG(2) << "-- argument " << *a;
+    for (const auto& a : argument_buffers) {
+      VLOG(2) << "-- argument " << a;
     }
   }
 
@@ -71,7 +87,7 @@ StatusOr InterpreterExecutable::ExecuteAsyncOnStream(
   // Check that the args have the right shape.
   for (int64 i = 0; i < computation->num_parameters(); ++i) {
     const auto& expected_shape = computation->parameter_instruction(i)->shape();
-    const auto& actual_shape = arguments[i]->on_device_shape();
+    const auto& actual_shape = argument_buffers[i].on_device_shape();
     if (!Shape::Equal().MinorToMajorOnlyInLayout()(expected_shape,
                                                    actual_shape)) {
       return InvalidArgument(
@@ -90,7 +106,7 @@ StatusOr InterpreterExecutable::ExecuteAsyncOnStream(
   for (int64 p = 0; p < computation->num_parameters(); ++p) {
     TF_ASSIGN_OR_RETURN(Literal arg_literal,
                         transfer_manager->TransferLiteralFromDevice(
-                            run_options->stream(), *arguments[p]));
+                            run_options->stream(), argument_buffers[p]));
     arg_literals.push_back(std::move(arg_literal));
   }
 
@@ -119,7 +135,16 @@ StatusOr InterpreterExecutable::ExecuteAsyncOnStream(
     profile->set_compute_time_ns(std::max(nanoseconds, 1.0));
   }
 
-  return std::move(result);
+  std::vector buffers_to_free;
+  for (ShapeTree& argument : arguments) {
+    for (std::pair& buffer : argument) {
+      auto maybe_owning_buffer = buffer.second.Release();
+      if (maybe_owning_buffer) {
+        buffers_to_free.push_back(std::move(*maybe_owning_buffer));
+      }
+    }
+  }
+  return ExecutionOutput(std::move(result), std::move(buffers_to_free), {}, {});
 }
 
 /*static*/ int64 InterpreterExecutable::ShapeSizeBytes(const Shape& shape) {
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h
index ba010de76bd..1bea6773fdd 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.h
+++ b/tensorflow/compiler/xla/service/interpreter/executable.h
@@ -46,9 +46,9 @@ class InterpreterExecutable : public Executable {
                         std::unique_ptr evaluator);
   ~InterpreterExecutable() override;
 
-  StatusOr ExecuteAsyncOnStream(
+  StatusOr ExecuteAsyncOnStream(
       const ServiceExecutableRunOptions* run_options,
-      absl::Span arguments,
+      std::vector> arguments,
       HloExecutionProfile* hlo_execution_profile) override
       LOCKS_EXCLUDED(evaluator_lock_);
 
diff --git a/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc b/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc
index 5fe5fea71ac..c4bf48bcc00 100644
--- a/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc
+++ b/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc
@@ -17,7 +17,8 @@ limitations under the License.
 #include "absl/types/variant.h"
 namespace xla {
 
-tensorflow::se::DeviceMemoryBase MaybeOwningDeviceMemory::AsDeviceMemoryBase() {
+tensorflow::se::DeviceMemoryBase MaybeOwningDeviceMemory::AsDeviceMemoryBase()
+    const {
   if (HasOwnership()) {
     return *absl::get(mem_);
   } else {
diff --git a/tensorflow/compiler/xla/service/maybe_owning_device_memory.h b/tensorflow/compiler/xla/service/maybe_owning_device_memory.h
index 8edd64cf681..7d23d178130 100644
--- a/tensorflow/compiler/xla/service/maybe_owning_device_memory.h
+++ b/tensorflow/compiler/xla/service/maybe_owning_device_memory.h
@@ -49,7 +49,7 @@ class MaybeOwningDeviceMemory {
 
   // Fetches the underlying DeviceMemoryBase from a MaybeOwningDeviceMemory. The
   // caller of this function is *not* responsible for freeing the memory.
-  tensorflow::se::DeviceMemoryBase AsDeviceMemoryBase();
+  tensorflow::se::DeviceMemoryBase AsDeviceMemoryBase() const;
 
   // Release the tensorflow::se::OwningDeviceMemory without freeing it, and
   // moves the ownership of the memory buffer from the object to the caller.

From 05ed82f0f4c00982cd274a2ebc5d0aeb759eebcb Mon Sep 17 00:00:00 2001
From: YoungSeok Yoon 
Date: Mon, 2 Dec 2019 20:43:51 -0800
Subject: [PATCH 201/279] Make the TFLite with select TF ops for iOS composable

PiperOrigin-RevId: 283472625
Change-Id: I59b4c378465043c5dfbb61e454073afadade2ac1
---
 .bazelrc                                      |  3 +-
 tensorflow/lite/experimental/ios/BUILD.apple  | 32 +++++++++++--------
 .../ios/TensorFlowLiteSelectTfOps.md          | 19 +++++++++++
 ...TensorFlowLiteSelectTfOps.podspec.template | 21 ++++++++++++
 4 files changed, 60 insertions(+), 15 deletions(-)
 create mode 100644 tensorflow/lite/experimental/ios/TensorFlowLiteSelectTfOps.md
 create mode 100644 tensorflow/lite/experimental/ios/TensorFlowLiteSelectTfOps.podspec.template

diff --git a/.bazelrc b/.bazelrc
index 5fd28e867c0..451cc60fdd1 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -100,9 +100,9 @@ build --apple_platform_type=macos
 # iOS configs for each architecture and the fat binary builds.
 build:ios --apple_platform_type=ios
 build:ios --apple_bitcode=embedded --copt=-fembed-bitcode
+build:ios --copt=-Wno-c++11-narrowing
 build:ios_armv7 --config=ios
 build:ios_armv7 --cpu=ios_armv7
-build:ios_armv7 --copt -Wno-c++11-narrowing
 build:ios_arm64 --config=ios
 build:ios_arm64 --cpu=ios_arm64
 build:ios_i386 --config=ios
@@ -111,7 +111,6 @@ build:ios_x86_64 --config=ios
 build:ios_x86_64 --cpu=ios_x86_64
 build:ios_fat --config=ios
 build:ios_fat --ios_multi_cpus=armv7,arm64,i386,x86_64
-build:ios_fat --copt -Wno-c++11-narrowing
 
 # Config to use a mostly-static build and disable modular op registration
 # support (this will revert to loading TensorFlow with RTLD_GLOBAL in Python).
diff --git a/tensorflow/lite/experimental/ios/BUILD.apple b/tensorflow/lite/experimental/ios/BUILD.apple
index 6ecd3d589ea..cf81057b167 100644
--- a/tensorflow/lite/experimental/ios/BUILD.apple
+++ b/tensorflow/lite/experimental/ios/BUILD.apple
@@ -26,18 +26,6 @@ ios_static_framework(
     ],
 )
 
-# bazel build -c opt --config=ios --ios_multi_cpus=armv7,arm64,x86_64 //tensorflow/lite/experimental/ios:TensorFlowLiteCWithSelectTfOps_framework
-ios_static_framework(
-    name = "TensorFlowLiteCWithSelectTfOps_framework",
-    hdrs = TFL_LIBRARY_HDRS,
-    bundle_name = "TensorFlowLiteC",
-    minimum_os_version = TFL_MINIMUM_OS_VERSION,
-    deps = [
-        ":TensorFlowLiteC",
-        "//tensorflow/lite/delegates/flex:delegate",
-    ],
-)
-
 objc_library(
     name = "TensorFlowLiteC",
     hdrs = TFL_LIBRARY_HDRS,
@@ -50,6 +38,24 @@ objc_library(
     ],
 )
 
+# This target builds the flex delegate as a separate static framework, which
+# does not include the TensorFlow Lite runtime. As this target does not contain
+# TensorFlow Lite runtime, it is intended to be linked along with the
+# TensorFlowLiteC framework above in a composable way.
+#
+# The flex delegate cannot be built for i386, so it can't be built with ios_fat
+# config.
+#
+# bazel build -c opt --config=ios --ios_multi_cpus=armv7,arm64,x86_64 //tensorflow/lite/experimental/ios:TensorFlowLiteSelectTfOps_framework
+ios_static_framework(
+    name = "TensorFlowLiteSelectTfOps_framework",
+    bundle_name = "TensorFlowLiteSelectTfOps",
+    minimum_os_version = TFL_MINIMUM_OS_VERSION,
+    deps = [
+        "//tensorflow/lite/delegates/flex:delegate",
+    ],
+)
+
 # Using this intermediate target is a workaround for a bug in bazel build rules
 # involving mixed objc_library & cc_library deps mentioned in (b/74809458).
 # When these dependencies are declared directly under the "TensorFlowLiteC"
@@ -79,6 +85,6 @@ build_test(
     ],
     targets = [
         ":TensorFlowLiteC_framework",
-        ":TensorFlowLiteCWithSelectTfOps_framework",
+        ":TensorFlowLiteSelectTfOps_framework",
     ],
 )
diff --git a/tensorflow/lite/experimental/ios/TensorFlowLiteSelectTfOps.md b/tensorflow/lite/experimental/ios/TensorFlowLiteSelectTfOps.md
new file mode 100644
index 00000000000..525049db2b7
--- /dev/null
+++ b/tensorflow/lite/experimental/ios/TensorFlowLiteSelectTfOps.md
@@ -0,0 +1,19 @@
+# TensorFlow Lite with Select TensorFlow ops
+
+For enabling the Select TensorFlow ops for your TensorFlow Lite app, please add
+the `TensorFlowLiteSelectTfOps` pod to your Podfile, in addition to
+`TensorFlowLiteSwift` or `TensorFlowLiteObjC` pod, depending on your primary
+language.
+
+After that, you should also force load the framework from your project. Add the
+following line to the `Other Linker Flags` under your project's Build Settings
+page.
+
+```
+-force_load "$(PROJECT_DIR)/Pods/TensorFlowLiteSelectTfOps/Frameworks/TensorFlowLiteSelectTfOps.framework/TensorFlowLiteSelectTfOps"
+```
+
+Please refer to the [Select operators from TensorFlow][ops-select] guide for
+more details.
+
+[ops-select]: https://www.tensorflow.org/lite/guide/ops_select#ios
diff --git a/tensorflow/lite/experimental/ios/TensorFlowLiteSelectTfOps.podspec.template b/tensorflow/lite/experimental/ios/TensorFlowLiteSelectTfOps.podspec.template
new file mode 100644
index 00000000000..7a91e4a08ce
--- /dev/null
+++ b/tensorflow/lite/experimental/ios/TensorFlowLiteSelectTfOps.podspec.template
@@ -0,0 +1,21 @@
+Pod::Spec.new do |s|
+  s.name             = 'TensorFlowLiteSelectTfOps'
+  s.version          = '${TFL_BUILD_VERSION}'
+  s.authors          = 'Google Inc.'
+  s.license          = { :type => 'Apache' }
+  s.homepage         = 'https://github.com/tensorflow/tensorflow'
+  s.source           = { :http => "${TFL_DOWNLOAD_URL}" }
+  s.summary          = 'TensorFlow Lite'
+  s.description      = <<-DESC
+
+  This pod can be used in addition to `TensorFlowLiteSwift` or
+  `TensorFlowLiteObjC` pod, in order to enable Select TensorFlow ops. The
+  resulting binary should also be force-loaded to the final app binary.
+                       DESC
+
+  s.ios.deployment_target = '9.0'
+
+  s.module_name = 'TensorFlowLiteSelectTfOps'
+  s.library = 'c++'
+  s.vendored_frameworks = 'Frameworks/TensorFlowLiteSelectTfOps.framework'
+end

From 0c3157fc648325583fe2eb62450a5b3d1e2c1eca Mon Sep 17 00:00:00 2001
From: Renjie Liu 
Date: Mon, 2 Dec 2019 20:44:44 -0800
Subject: [PATCH 202/279] change optimized quantized int8 path 4d mean to be
 integer-only.

PiperOrigin-RevId: 283472731
Change-Id: I15765dd32be9c67b0295ac68494b3328eb33f32b
---
 .../internal/optimized/integer_ops/mean.h     | 173 ++++++++++--------
 tensorflow/lite/kernels/reduce_test.cc        |  20 ++
 2 files changed, 113 insertions(+), 80 deletions(-)

diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/mean.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/mean.h
index 44f7040c089..5a9d4df9aa6 100644
--- a/tensorflow/lite/kernels/internal/optimized/integer_ops/mean.h
+++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/mean.h
@@ -23,18 +23,10 @@ limitations under the License.
 namespace tflite {
 namespace optimized_integer_ops {
 
-#ifdef USE_NEON
-
-using optimized_ops::DivideSumForMeanImpl;
-using optimized_ops::RoundToNearest;
-
-#endif  // USE_NEON
-
 inline void MeanImpl(const tflite::MeanParams& op_params,
                      const RuntimeShape& input_shape, const int8_t* input_data,
-                     int32 input_zero_point, float input_scale,
+                     int32 multiplier, int32 shift, int32 bias,
                      const RuntimeShape& output_shape, int8_t* output_data,
-                     int32 output_zero_point, float output_scale,
                      int start_depth, int end_depth) {
   gemmlowp::ScopedProfilingLabel label("Mean4D/Int8/MeanImpl");
 
@@ -45,7 +37,6 @@ inline void MeanImpl(const tflite::MeanParams& op_params,
   const int output_width = output_shape.Dims(2);
   const int input_height = input_shape.Dims(1);
   const int input_width = input_shape.Dims(2);
-  const float num_elements_in_axis = input_width * input_height;
 
   TFLITE_CHECK_EQ(op_params.axis_count, 2);
   TFLITE_CHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
@@ -53,82 +44,98 @@ inline void MeanImpl(const tflite::MeanParams& op_params,
   TFLITE_CHECK_EQ(output_height, 1);
   TFLITE_CHECK_EQ(output_width, 1);
 
-  const bool ordinary_mean =
-      (input_zero_point == output_zero_point && input_scale == output_scale);
-  float scale = 0.0f, bias = 0.0f;
-  if (!ordinary_mean) {
-    scale = input_scale / output_scale;
-    bias = -input_zero_point * scale + 0.5;
-  }
+  constexpr static int32_t kMinValue = std::numeric_limits::min();
+  constexpr static int32_t kMaxValue = std::numeric_limits::max();
 
 #ifdef USE_NEON
-  const float32x4_t num_elements_dup = vdupq_n_f32(num_elements_in_axis);
-  // This is only an approximation as NEON does not offer division instruction.
-  const float32x4_t scale_dup = vdupq_n_f32(scale);
-  const float32x4_t num_elements_reverse = vrecpeq_f32(num_elements_dup);
-  float32x4_t zero_point_with_bias_dup = vdupq_n_f32(output_zero_point + bias);
+  const int32x4_t bias_dup = vdupq_n_s32(bias);
+  const int32x4_t min_dup = vdupq_n_s32(kMinValue);
+  const int32x4_t max_dup = vdupq_n_s32(kMaxValue);
 #endif  // USE_NEON
-
   for (int out_b = 0; out_b < output_batch; ++out_b) {
     int out_d = start_depth;
 #ifdef USE_NEON
 
-    for (; out_d < end_depth - 8; out_d += 8) {
-      float32x4_t temp_sum_1 = vdupq_n_f32(0);
-      float32x4_t temp_sum_2 = vdupq_n_f32(0);
+    for (; out_d <= end_depth - 16; out_d += 16) {
+      int32x4x4_t temp_sum;
+      temp_sum.val[0] = vdupq_n_s32(0);
+      temp_sum.val[1] = vdupq_n_s32(0);
+      temp_sum.val[2] = vdupq_n_s32(0);
+      temp_sum.val[3] = vdupq_n_s32(0);
       for (int in_h = 0; in_h < input_height; ++in_h) {
         for (int in_w = 0; in_w < input_width; ++in_w) {
           const int8_t* input_data_ptr =
               input_data + Offset(input_shape, out_b, in_h, in_w, out_d);
-          int8x8_t input_data_val = vld1_s8(input_data_ptr);
-          int16x8_t input_data_val_shift = vmovl_s8(input_data_val);
-          float32x4_t input_float_1 =
-              vcvtq_f32_s32(vmovl_s16(vget_high_s16(input_data_val_shift)));
-          float32x4_t input_float_2 =
-              vcvtq_f32_s32(vmovl_s16(vget_low_s16(input_data_val_shift)));
-          temp_sum_1 = vaddq_f32(temp_sum_1, input_float_1);
-          temp_sum_2 = vaddq_f32(temp_sum_2, input_float_2);
+          int8x16_t input_data_val = vld1q_s8(input_data_ptr);
+
+          int16x8_t input_data_low_shift =
+              vmovl_s8(vget_low_s8(input_data_val));
+          int16x8_t input_data_high_shift =
+              vmovl_s8(vget_high_s8(input_data_val));
+
+          int32x4_t input_low_low =
+              vmovl_s16(vget_low_s16(input_data_low_shift));
+          int32x4_t input_high_low =
+              vmovl_s16(vget_high_s16(input_data_low_shift));
+          int32x4_t input_low_high =
+              vmovl_s16(vget_low_s16(input_data_high_shift));
+          int32x4_t input_high_high =
+              vmovl_s16(vget_high_s16(input_data_high_shift));
+
+          temp_sum.val[0] = vaddq_s32(temp_sum.val[0], input_low_low);
+          temp_sum.val[1] = vaddq_s32(temp_sum.val[1], input_high_low);
+          temp_sum.val[2] = vaddq_s32(temp_sum.val[2], input_low_high);
+          temp_sum.val[3] = vaddq_s32(temp_sum.val[3], input_high_high);
         }
       }
 
-      const float32x4_t mean_1 =
-          DivideSumForMeanImpl(temp_sum_1, num_elements_reverse, ordinary_mean,
-                               scale_dup, zero_point_with_bias_dup);
-      const float32x4_t mean_2 =
-          DivideSumForMeanImpl(temp_sum_2, num_elements_reverse, ordinary_mean,
-                               scale_dup, zero_point_with_bias_dup);
+      temp_sum = optimized_ops::MultiplyByQuantizedMultiplier4Rows(
+          temp_sum, multiplier, shift);
+
+      temp_sum.val[0] = vaddq_s32(temp_sum.val[0], bias_dup);
+      temp_sum.val[1] = vaddq_s32(temp_sum.val[1], bias_dup);
+      temp_sum.val[2] = vaddq_s32(temp_sum.val[2], bias_dup);
+      temp_sum.val[3] = vaddq_s32(temp_sum.val[3], bias_dup);
+
+      temp_sum.val[0] = vminq_s32(vmaxq_s32(temp_sum.val[0], min_dup), max_dup);
+      temp_sum.val[1] = vminq_s32(vmaxq_s32(temp_sum.val[1], min_dup), max_dup);
+      temp_sum.val[2] = vminq_s32(vmaxq_s32(temp_sum.val[2], min_dup), max_dup);
+      temp_sum.val[3] = vminq_s32(vmaxq_s32(temp_sum.val[3], min_dup), max_dup);
+
+      int16x4_t narrowed_low_low = vmovn_s32(temp_sum.val[0]);
+      int16x4_t narrowed_high_low = vmovn_s32(temp_sum.val[1]);
+      int16x4_t narrowed_low_high = vmovn_s32(temp_sum.val[2]);
+      int16x4_t narrowed_high_high = vmovn_s32(temp_sum.val[3]);
+
+      int16x8_t combined_low =
+          vcombine_s16(narrowed_low_low, narrowed_high_low);
+      int16x8_t combined_high =
+          vcombine_s16(narrowed_low_high, narrowed_high_high);
+
+      int8x8_t narrowed_low = vmovn_s16(combined_low);
+      int8x8_t narrowed_high = vmovn_s16(combined_high);
+
+      int8x16_t combined_output = vcombine_s8(narrowed_low, narrowed_high);
 
-      int32x4_t casted_mean_1 = RoundToNearest(mean_1);
-      int16x4_t narrow_range_mean_1 = vmovn_s32(casted_mean_1);
-      int32x4_t casted_mean_2 = RoundToNearest(mean_2);
-      int16x4_t narrow_range_mean_2 = vmovn_s32(casted_mean_2);
-      int16x8_t combined_mean =
-          vcombine_s16(narrow_range_mean_2, narrow_range_mean_1);
-      int8x8_t narrowed_combined_mean = vmovn_s16(combined_mean);
       int8_t* output_data_ptr =
           output_data + Offset(output_shape, out_b, 0, 0, out_d);
-      vst1_s8(output_data_ptr, narrowed_combined_mean);
+      vst1q_s8(output_data_ptr, combined_output);
     }
 #endif  // USE_NEON
 
     for (; out_d < end_depth; ++out_d) {
-      float temp_value = 0;
+      int acc = 0;
       for (int in_h = 0; in_h < input_height; ++in_h) {
         for (int in_w = 0; in_w < input_width; ++in_w) {
-          temp_value +=
-              input_data[Offset(input_shape, out_b, in_h, in_w, out_d)];
+          acc += input_data[Offset(input_shape, out_b, in_h, in_w, out_d)];
         }
       }
 
-      temp_value = temp_value / num_elements_in_axis;
-      if (ordinary_mean) {
-        output_data[Offset(output_shape, out_b, 0, 0, out_d)] =
-            static_cast(round(temp_value));
-      } else {
-        output_data[Offset(output_shape, out_b, 0, 0, out_d)] =
-            static_cast(round(temp_value * scale + bias)) +
-            output_zero_point;
-      }
+      acc = MultiplyByQuantizedMultiplier(acc, multiplier, shift);
+      acc += bias;
+      acc = std::min(std::max(acc, kMinValue), kMaxValue);
+      output_data[Offset(output_shape, out_b, 0, 0, out_d)] =
+          static_cast(acc);
     }
   }
 }
@@ -136,38 +143,34 @@ inline void MeanImpl(const tflite::MeanParams& op_params,
 struct MeanWorkerTask : cpu_backend_threadpool::Task {
   MeanWorkerTask(const tflite::MeanParams& op_params,
                  const RuntimeShape& input_shape, const int8_t* input_data,
-                 int32 input_zero_point, float input_scale,
+                 int32 multiplier, int32 shift, int32 bias,
                  const RuntimeShape& output_shape, int8_t* output_data,
-                 int32 output_zero_point, float output_scale, int start_height,
-                 int end_height)
+                 int start_height, int end_height)
       : op_params(op_params),
         input_shape(input_shape),
         input_data(input_data),
-        input_zero_point(input_zero_point),
-        input_scale(input_scale),
+        multiplier(multiplier),
+        shift(shift),
+        bias(bias),
         output_shape(output_shape),
         output_data(output_data),
-        output_zero_point(output_zero_point),
-        output_scale(output_scale),
         start_height(start_height),
         end_height(end_height) {}
 
   void Run() override {
-    MeanImpl(op_params, input_shape, input_data, input_zero_point, input_scale,
-             output_shape, output_data, output_zero_point, output_scale,
-             start_height, end_height);
+    MeanImpl(op_params, input_shape, input_data, multiplier, shift, bias,
+             output_shape, output_data, start_height, end_height);
   }
 
  private:
   const tflite::MeanParams& op_params;
   const RuntimeShape& input_shape;
   const int8_t* input_data;
-  int32 input_zero_point;
-  float input_scale;
+  int32 multiplier;
+  int32 shift;
+  int32 bias;
   const RuntimeShape& output_shape;
   int8_t* output_data;
-  int32 output_zero_point;
-  float output_scale;
   int start_height;
   int end_height;
 };
@@ -197,6 +200,18 @@ inline void Mean(const tflite::MeanParams& op_params,
   TFLITE_CHECK_EQ(output_height, 1);
   TFLITE_CHECK_EQ(output_width, 1);
 
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const float num_elements_in_axis = input_width * input_height;
+
+  int32 bias =
+      output_zero_point -
+      static_cast(input_zero_point * input_scale / output_scale);
+  float real_scale = input_scale / (num_elements_in_axis * output_scale);
+
+  int32 multiplier, shift;
+  QuantizeMultiplier(real_scale, &multiplier, &shift);
+
   constexpr int kMinDepthPerThread = 8;
   int thread_count = output_depth / kMinDepthPerThread;
   thread_count = thread_count > 0 ? thread_count : 1;
@@ -204,9 +219,8 @@ inline void Mean(const tflite::MeanParams& op_params,
       std::min(thread_count, cpu_backend_context->max_num_threads());
 
   if (capped_thread_count == 1) {
-    MeanImpl(op_params, input_shape, input_data, input_zero_point, input_scale,
-             output_shape, output_data, output_zero_point, output_scale, 0,
-             output_depth);
+    MeanImpl(op_params, input_shape, input_data, multiplier, shift, bias,
+             output_shape, output_data, 0, output_depth);
   } else {
     // Instead parrallel for batch, we loop for the output_depth since batch
     // is typical 1.
@@ -219,9 +233,8 @@ inline void Mean(const tflite::MeanParams& op_params,
       // Try to distribute the tasks as even as possible.
       int depth_end = depth_start +
                       (output_depth - depth_start) / (capped_thread_count - i);
-      tasks.emplace_back(op_params, input_shape, input_data, input_zero_point,
-                         input_scale, output_shape, output_data,
-                         output_zero_point, output_scale, depth_start,
+      tasks.emplace_back(op_params, input_shape, input_data, multiplier, shift,
+                         bias, output_shape, output_data, depth_start,
                          depth_end);
       depth_start = depth_end;
     }
diff --git a/tensorflow/lite/kernels/reduce_test.cc b/tensorflow/lite/kernels/reduce_test.cc
index 12b94e2019c..2bcfedaaf9f 100644
--- a/tensorflow/lite/kernels/reduce_test.cc
+++ b/tensorflow/lite/kernels/reduce_test.cc
@@ -458,6 +458,26 @@ TEST(ConstInt8MeanOpTest, QuantizedDifferentScale) {
                   kQuantizedTolerance)));
 }
 
+TEST(ConstFloatMeanOpTest, KeepDims4DMeanLargeDepthInt8) {
+  float kQuantizedTolerance = GetTolerance(-5.0, 5.0);
+  std::vector data = {
+      0.1, 0.2, 0.3, 0.4, 0.2, 0.3, 0.4, 0.5, 0.1, 0.1, 0.1, 0.1, 0.4, 0.2, 0.2,
+      0.2, 0.9, 0.9, 0.9, 0.9, 0.2, 0.3, 0.7, 0.7, 0.1, 0.1, 0.3, 0.3, 0.1, 0.2,
+      0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.2, 0.3, 0.4, 0.5, 0.1,
+      0.1, 0.1, 0.1, 0.4, 0.2, 0.2, 0.2, 0.9, 0.9, 0.9, 0.9, 0.2, 0.3, 0.7, 0.7,
+      0.1, 0.1, 0.3, 0.3, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4};
+  MeanOpConstModel m({TensorType_INT8, {1, 2, 2, 18}, -1.0, 1.0},
+                     {TensorType_INT8, {2}, -1.0, 1.0}, {2}, {1, 2}, true);
+  m.QuantizeAndPopulate(m.Input(), data);
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 18}));
+  EXPECT_THAT(m.GetDequantizedOutput(),
+              ElementsAreArray(ArrayFloatNear(
+                  {0.5, 0.55, 0.25, 0.35, 0.45, 0.5, 0.25, 0.3, 0.2, 0.2, 0.1,
+                   0.15, 0.35, 0.3, 0.15, 0.2, 0.6, 0.65},
+                  kQuantizedTolerance)));
+}
+
 TEST(DynamicUint8MeanOpTest, NotKeepDims) {
   float kQuantizedTolerance = GetTolerance(-5.0, 2.0);
   std::vector data = {1.3, -4.8, -3.6, 0.24};

From c397ed9305915dbd9b57780ab97fbb4c600cdf5f Mon Sep 17 00:00:00 2001
From: Gaurav Jain 
Date: Mon, 2 Dec 2019 20:57:29 -0800
Subject: [PATCH 203/279] Avoid heap allocation of ExecuteNodeArgs

PiperOrigin-RevId: 283473749
Change-Id: I400c44fe00aa05e4f8cdf228638b6f3194fc4154
---
 .../core/common_runtime/eager/execute.cc       | 11 ++++-------
 .../core/common_runtime/eager/execute_node.h   | 18 ++++--------------
 .../common_runtime/eager/kernel_and_device.h   |  2 ++
 3 files changed, 10 insertions(+), 21 deletions(-)

diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 6895a79f767..32fdb21c1b4 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -1040,24 +1040,21 @@ Status EagerKernelExecute(
   profiler::TraceMe activity("EagerKernelExecute",
                              profiler::TraceMeLevel::kInfo);
   std::vector outputs(1);
-  gtl::InlinedVector input_vector(op_inputs.size());
 
-  std::unique_ptr inputs;
-  TF_RETURN_IF_ERROR(ExecuteNodeArgs::CreateExecuteNodeArgs(
-      std::move(input_vector), ctx, op_inputs, &inputs));
+  ExecuteNodeArgs inputs(op_inputs.size());
+  TF_RETURN_IF_ERROR(inputs.Init(ctx, op_inputs));
   // TODO(apassos) figure out how to record stats for ops which are a part of
   // functions.
-  // TODO(agarwal): change Run to take vector of handles ?
   // TODO(b/111859745): When we support recovering from kernel/device errors, we
   // would need to call XlaDevice::EnsureDeviceContextOk() before using an XLA
   // device. We don't call it now because it is an unneeded overhead (it
   // acquires a lock) and we can't recover from errors anyway.
   ScopedStepContainer* container = ctx->StepContainer();
   if (container == nullptr) {
-    TF_RETURN_IF_ERROR(kernel->Run(*inputs, &outputs, cancellation_manager,
+    TF_RETURN_IF_ERROR(kernel->Run(inputs, &outputs, cancellation_manager,
                                    remote_func_params));
   } else {
-    TF_RETURN_IF_ERROR(kernel->Run(container, *inputs, &outputs,
+    TF_RETURN_IF_ERROR(kernel->Run(container, inputs, &outputs,
                                    cancellation_manager, remote_func_params));
   }
   if (graph_collector != nullptr) {
diff --git a/tensorflow/core/common_runtime/eager/execute_node.h b/tensorflow/core/common_runtime/eager/execute_node.h
index 3fb53736078..08cecf56098 100644
--- a/tensorflow/core/common_runtime/eager/execute_node.h
+++ b/tensorflow/core/common_runtime/eager/execute_node.h
@@ -45,16 +45,12 @@ namespace tensorflow {
 
 class ExecuteNodeArgs : public EagerKernelArgs {
  public:
-  static Status CreateExecuteNodeArgs(
-      gtl::InlinedVector&& tensor_args, EagerContext* ctx,
-      const gtl::InlinedVector& op_inputs,
-      std::unique_ptr* args) {
-    args->reset(new ExecuteNodeArgs(std::move(tensor_args)));
-    return (*args)->Init(ctx, op_inputs);
-  }
-
+  explicit ExecuteNodeArgs(int count) : EagerKernelArgs(count) {}
   ~ExecuteNodeArgs() override;
 
+  Status Init(EagerContext* ctx,
+              const gtl::InlinedVector& op_inputs);
+
   bool HasRemoteInputs() const override { return has_remote_inputs_; };
 
 #if !defined(IS_MOBILE_PLATFORM)
@@ -65,12 +61,6 @@ class ExecuteNodeArgs : public EagerKernelArgs {
 #endif  // IS_MOBILE_PLATFORM
 
  private:
-  explicit ExecuteNodeArgs(gtl::InlinedVector&& tensor_args)
-      : EagerKernelArgs(std::move(tensor_args)) {}
-
-  Status Init(EagerContext* ctx,
-              const gtl::InlinedVector& op_inputs);
-
   bool has_remote_inputs_ = false;
   TensorReferenceVector protected_tensors_;
 #if !defined(IS_MOBILE_PLATFORM)
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h
index 395dcc98f78..04d97f2b80c 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.h
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h
@@ -59,6 +59,8 @@ class EagerKernelArgs : public FunctionArgsInterface {
  public:
   EagerKernelArgs() {}
 
+  explicit EagerKernelArgs(int count) : tensor_args_(count) {}
+
   explicit EagerKernelArgs(gtl::InlinedVector&& tensor_args)
       : tensor_args_(std::move(tensor_args)) {}
 

From 05a80058380ca074a3bc354cb4765aee6ce29be6 Mon Sep 17 00:00:00 2001
From: James Keeling 
Date: Tue, 3 Dec 2019 00:13:53 -0800
Subject: [PATCH 204/279] Only warn about unoptimized datasets every 30s

This LOG message had a tendency to be spammed many times per second for a significant proportion of the graph build time for some of our users. Seeing as most users will only ever use one dataset in any given run, these errors are all the same. I therefore just output one per 30s.

PiperOrigin-RevId: 283492175
Change-Id: Idd7d210d3ed7297dce6ede5d51893dcefcd523d8
---
 tensorflow/core/framework/dataset.cc | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc
index f27fa75eb7d..a1625b48408 100644
--- a/tensorflow/core/framework/dataset.cc
+++ b/tensorflow/core/framework/dataset.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/core/framework/variant_op_registry.h"
 #include "tensorflow/core/graph/graph_def_builder.h"
 #include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
 
@@ -389,7 +390,7 @@ Status DatasetBase::DatasetGraphDefBuilder::AddInputDataset(
     TF_RETURN_IF_ERROR(AddPlaceholder(t, output));
     DCHECK_NE(ctx->input_list(), nullptr);
     ctx->input_list()->emplace_back((*output)->name(), std::move(t));
-    LOG(WARNING)
+    LOG_EVERY_N_SEC(WARNING, 30)
         << "Input of " << dataset->DebugString()
         << " will not be optimized because the dataset does not implement the "
            "AsGraphDefInternal() method needed to apply optimizations.";

From f0bcbfc2ae35b91088de646d0e7978ebc04a0af5 Mon Sep 17 00:00:00 2001
From: Prakalp Srivastava 
Date: Tue, 3 Dec 2019 00:25:17 -0800
Subject: [PATCH 205/279] Combine all HLO dialect exporter tests in a single
 file.

This also fixes some of the tests to use regex instead of SSA names.

PiperOrigin-RevId: 283493218
Change-Id: Icb255fd048c6f2fc1290ff830903bc9959def6b8
---
 .../mlir/xla/tests/translate/all_reduce.mlir  |  26 -
 .../xla/tests/translate/batch_norm_grad.mlir  |  15 -
 .../tests/translate/batch_norm_training.mlir  |  13 -
 .../tests/translate/binary_arithmetic.mlir    |  23 -
 .../tests/translate/binary_op_broadcast.mlir  |  26 -
 .../mlir/xla/tests/translate/broadcast.mlir   |   9 -
 .../xla/tests/translate/broadcast_in_dim.mlir |  12 -
 .../mlir/xla/tests/translate/call.mlir        |  30 -
 .../translate/call_multiple_results.mlir      |  24 -
 .../mlir/xla/tests/translate/concatenate.mlir |  17 -
 .../mlir/xla/tests/translate/conditional.mlir |   4 +-
 .../mlir/xla/tests/translate/const.mlir       |  30 -
 .../mlir/xla/tests/translate/conv.mlir        |  31 -
 .../mlir/xla/tests/translate/convert.mlir     |  10 -
 .../mlir/xla/tests/translate/copy.mlir        |  10 -
 .../tests/translate/cross_replica_sum.mlir    |  16 -
 .../mlir/xla/tests/translate/einsum.mlir      |   9 -
 .../mlir/xla/tests/translate/export.mlir      | 622 ++++++++++++++++++
 .../tests/translate/get_dimension_size.mlir   |  10 -
 .../tests/translate/get_element_tuple.mlir    |  10 -
 .../translate/{ops.hlotxt => import.hlotxt}   |   0
 .../mlir/xla/tests/translate/iota.mlir        |  11 -
 .../mlir/xla/tests/translate/pad.mlir         |  12 -
 .../mlir/xla/tests/translate/reduce.mlir      |  24 -
 .../xla/tests/translate/reduce_window.mlir    |  27 -
 .../mlir/xla/tests/translate/reshape.mlir     |  10 -
 .../mlir/xla/tests/translate/reverse.mlir     |  12 -
 .../mlir/xla/tests/translate/rng_uniform.mlir |  14 -
 .../mlir/xla/tests/translate/scatter.mlir     |  27 -
 .../mlir/xla/tests/translate/select.mlir      |  14 -
 .../tests/translate/select_and_scatter.mlir   |  34 -
 .../mlir/xla/tests/translate/slice.mlir       |  11 -
 .../mlir/xla/tests/translate/transpose.mlir   |  11 -
 .../mlir/xla/tests/translate/tuple.mlir       |  11 -
 .../mlir/xla/tests/translate/unary_ops.mlir   |  21 -
 .../mlir/xla/tests/translate/xor.mlir         |  11 -
 36 files changed, 624 insertions(+), 573 deletions(-)
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/all_reduce.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/batch_norm_grad.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/batch_norm_training.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/binary_arithmetic.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/binary_op_broadcast.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/broadcast.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/broadcast_in_dim.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/call.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/call_multiple_results.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/concatenate.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/const.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/conv.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/convert.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/copy.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/cross_replica_sum.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/einsum.mlir
 create mode 100644 tensorflow/compiler/mlir/xla/tests/translate/export.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/get_dimension_size.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/get_element_tuple.mlir
 rename tensorflow/compiler/mlir/xla/tests/translate/{ops.hlotxt => import.hlotxt} (100%)
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/iota.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/pad.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/reduce.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/reduce_window.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/reshape.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/reverse.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/rng_uniform.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/scatter.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/select.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/select_and_scatter.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/slice.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/transpose.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/tuple.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/unary_ops.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/translate/xor.mlir

diff --git a/tensorflow/compiler/mlir/xla/tests/translate/all_reduce.mlir b/tensorflow/compiler/mlir/xla/tests/translate/all_reduce.mlir
deleted file mode 100644
index 6c418799da8..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/all_reduce.mlir
+++ /dev/null
@@ -1,26 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
-  %0 = "xla_hlo.all_reduce"(%arg0) ({
-  // Perform max reduction inside the region
-  ^bb0(%lhs: tensor, %rhs: tensor):
-    %max = xla_hlo.max %lhs, %rhs : tensor
-    "xla_hlo.return"(%max) : (tensor) -> ()
-  })
-  {
-    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
-    channel_id = {
-      handle = 5 : i64,
-      type = 2 : i64
-    }
-  } : (tensor<10xf32>) -> tensor<10xf32>
-  return %0 : tensor<10xf32>
-}
-
-// CHECK:  %[[COMPUTATION:.*]] ({{.*}}: f32[], {{.*}}: f32[]) -> f32[]
-// CHECK-LABEL:  ENTRY
-// CHECK:  %[[ARG0:.*]] = f32[10] parameter(0)
-// CHECK:  ROOT %[[RESULT:.*]] = f32[10] all-reduce(f32[10] %[[ARG0]])
-// CHECK-SAME:  channel_id=5
-// CHECK-SAME:  replica_groups={{[{][{]}}0,2,4,6},{1,3,5,7{{[}][}]}}
-// CHECK-SAME:  to_apply=%[[COMPUTATION]]
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/batch_norm_grad.mlir b/tensorflow/compiler/mlir/xla/tests/translate/batch_norm_grad.mlir
deleted file mode 100644
index fff194c627b..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/batch_norm_grad.mlir
+++ /dev/null
@@ -1,15 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tuple, tensor<2xf32>, tensor<2xf32>> {
-  %0 = "xla_hlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> tuple, tensor<2xf32>, tensor<2xf32>>
-  return %0 : tuple, tensor<2xf32>, tensor<2xf32>>
-}
-
-// CHECK-LABEL:  ENTRY
-// CHECK:  [[VAL_1:%.*]] = f32[2,2,2,2] parameter(0)
-// CHECK:  [[VAL_2:%.*]] = f32[2] parameter(1)
-// CHECK:  [[VAL_3:%.*]] = f32[2] parameter(2)
-// CHECK:  [[VAL_4:%.*]] = f32[2] parameter(3)
-// CHECK:  [[VAL_5:%.*]] = f32[2,2,2,2] parameter(4)
-// CHECK-LABEL:  ROOT
-// CHECK-SAME:  (f32[2,2,2,2], f32[2], f32[2]) batch-norm-grad(f32[2,2,2,2] [[VAL_1]], f32[2] [[VAL_2]], f32[2] [[VAL_3]], f32[2] [[VAL_4]], f32[2,2,2,2] [[VAL_5]]), epsilon=0.001, feature_index=0
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/batch_norm_training.mlir b/tensorflow/compiler/mlir/xla/tests/translate/batch_norm_training.mlir
deleted file mode 100644
index d51e801b438..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/batch_norm_training.mlir
+++ /dev/null
@@ -1,13 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %offset: tensor<2xf32>) -> tuple, tensor<2xf32>, tensor<2xf32>> {
-  %0 = "xla_hlo.batch_norm_training" (%input, %scale, %offset) {epsilon = 0.001 : f32, feature_index = 3 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tuple, tensor<2xf32>, tensor<2xf32>>
-  return %0 : tuple, tensor<2xf32>, tensor<2xf32>>
-}
-
-// CHECK-LABEL:  ENTRY
-// CHECK:  [[VAL_1:%.*]] = f32[2,2,2,2] parameter(0)
-// CHECK:  [[VAL_2:%.*]] = f32[2] parameter(1)
-// CHECK:  [[VAL_3:%.*]] = f32[2] parameter(2)
-// CHECK-LABEL:  ROOT
-// CHECK-SAME:  (f32[2,2,2,2], f32[2], f32[2]) batch-norm-training(f32[2,2,2,2] [[VAL_1]], f32[2] [[VAL_2]], f32[2] [[VAL_3]]), epsilon=0.001, feature_index=3
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/binary_arithmetic.mlir b/tensorflow/compiler/mlir/xla/tests/translate/binary_arithmetic.mlir
deleted file mode 100644
index 50f10739816..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/binary_arithmetic.mlir
+++ /dev/null
@@ -1,23 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-module {
-  func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
-    // CHECK: [[VAL_1:%.*]] = s32[4] parameter(0)
-    // CHECK: [[VAL_2:%.*]] = s32[4] parameter(1)
-    // CHECK:  [[ATAN2:%.*]] = s32[4] atan2(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
-    %0 = xla_hlo.atan2 %arg0, %arg1 : tensor<4xi32>
-
-    // CHECK:  [[SHL:%.*]] = s32[4] shift-left(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
-    %1 = xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32>
-
-    // CHECK:  [[SHRA:%.*]] = s32[4] shift-right-arithmetic(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
-    %2 = xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>
-
-    // CHECK:  [[SHRL:%.*]] = s32[4] shift-right-logical(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
-    %3 = xla_hlo.shift_right_logical %arg0, %arg1 : tensor<4xi32>
-
-    // CHECK-LABEL:  ROOT
-    // CHECK-SAME:  [[VAL_7:%.*]] = (s32[4], s32[4], s32[4], s32[4]) tuple(s32[4] [[ATAN2]], s32[4] [[SHL]], s32[4] [[SHRA]], s32[4] [[SHRL]])
-    return %0, %1, %2, %3 : tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>
-  }
-}
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/binary_op_broadcast.mlir b/tensorflow/compiler/mlir/xla/tests/translate/binary_op_broadcast.mlir
deleted file mode 100644
index 38aa4f04bad..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/binary_op_broadcast.mlir
+++ /dev/null
@@ -1,26 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-// CHECK-LABEL: ENTRY %main.13 (Arg_0.1: s32[1,4], Arg_1.2: s32[2,4], Arg_2.3: s32[2,3,4]) -> s32[2,3,4] {
-func @main(%arg0: tensor<1x4xi32>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3x4xi32>) -> tensor<2x3x4xi32> {
-  // Same rank degenerate broadcast
-  // CHECK-NEXT: %Arg_0.1 = s32[1,4] parameter(0)
-  // CHECK-NEXT: %reshape.4 = s32[4] reshape(s32[1,4] %Arg_0.1)
-  // CHECK-NEXT: %broadcast.5 = s32[2,4] broadcast(s32[4] %reshape.4)
-  // CHECK-NEXT: %Arg_1.2 = s32[2,4] parameter(1)
-  // CHECK-NEXT: %add.6 = s32[2,4] add(s32[2,4] %broadcast.5, s32[2,4] %Arg_1.2)
-  %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<1x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
-
-  // Broadcast up rank
-  // CHECK-NEXT: %broadcast.7 = s32[2,3,4] broadcast(s32[2,4] %Arg_1.2), dimensions={0,2}
-  // CHECK-NEXT: %Arg_2.3 = s32[2,3,4] parameter(2)
-  // CHECK-NEXT: %add.8 = s32[2,3,4] add(s32[2,3,4] %broadcast.7, s32[2,3,4] %Arg_2.3)
-  %1 = "xla_hlo.add"(%arg1, %arg2) {broadcast_dimensions = dense<[0,2]> : tensor<2xi64>} : (tensor<2x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>
-
-  // Broadcast up rank + degenerate broadcast
-  // CHECK-NEXT: %broadcast.9 = s32[2,1,4] broadcast(s32[1,4] %Arg_0.1), dimensions={1,2}
-  // CHECK-NEXT: %reshape.10 = s32[2,4] reshape(s32[2,1,4] %broadcast.9)
-  // CHECK-NEXT: %broadcast.11 = s32[2,3,4] broadcast(s32[2,4] %reshape.10), dimensions={0,2}
-  // CHECK-NEXT: ROOT %add.12 = s32[2,3,4] add(s32[2,3,4] %broadcast.11, s32[2,3,4] %Arg_2.3)
-  %2 = "xla_hlo.add"(%arg0, %arg2) {broadcast_dimensions = dense<[1,2]> : tensor<2xi64>} : (tensor<1x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>
-  return %2 : tensor<2x3x4xi32>
-}
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/broadcast.mlir b/tensorflow/compiler/mlir/xla/tests/translate/broadcast.mlir
deleted file mode 100644
index 0b64ab23d54..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/broadcast.mlir
+++ /dev/null
@@ -1,9 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-// CHECK-LABEL: ENTRY %main.3 (Arg_0.1: s32[4]) -> s32[1,2,3,4] {
-func @main(%arg0: tensor<4xi32>) -> tensor<1x2x3x4xi32> {
-  // CHECK-NEXT: %Arg_0.1 = s32[4] parameter(0)
-  // CHECK-NEXT: ROOT %broadcast.2 = s32[1,2,3,4] broadcast(s32[4] %Arg_0.1), dimensions={3}
-  %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32>
-  return %0 : tensor<1x2x3x4xi32>
-}
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/broadcast_in_dim.mlir b/tensorflow/compiler/mlir/xla/tests/translate/broadcast_in_dim.mlir
deleted file mode 100644
index ac53ba9dbbe..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/broadcast_in_dim.mlir
+++ /dev/null
@@ -1,12 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-func @main(%arg0: tensor<1xf32>) -> tensor<1x10xf32> {
-  %result = "xla_hlo.broadcast_in_dim"(%arg0) {
-    broadcast_dimensions = dense<0> : tensor<1xi64>
-  } : (tensor<1xf32>) -> tensor<1x10xf32>
-  return %result : tensor<1x10xf32>
-}
-
-// CHECK: ENTRY %main.3 ([[ARG0:.*]]: f32[1]) -> f32[1,10] {
-// CHECK:  %[[ARG0]] = f32[1] parameter(0)
-// CHECK:  ROOT %broadcast.2 = f32[1,10] broadcast(f32[1] %[[ARG0]]), dimensions={0}
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/call.mlir b/tensorflow/compiler/mlir/xla/tests/translate/call.mlir
deleted file mode 100644
index e9cfefc308d..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/call.mlir
+++ /dev/null
@@ -1,30 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-func @main(%arg0: tensor<4xi32>) -> tensor<4xi32> {
-  %0 = call @callee(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
-  %1 = call @callee(%0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
-  return %1 : tensor<4xi32>
-}
-
-func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
-  %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
-  return %0 : tensor<4xi32>
-}
-
-// CHECK:  [[CALLEE_1:%.*]] ([[ARG_1:.*]]: s32[4], [[ARG_2:.*]]: s32[4]) -> s32[4] {
-// CHECK:  %[[ARG_1]] = s32[4] parameter(0)
-// CHECK:  %[[ARG_2]] = s32[4] parameter(1)
-// CHECK-LABEL:  ROOT
-// CHECK-SAME:  s32[4] add(s32[4] %[[ARG_1]], s32[4] %[[ARG_2]])
-
-// CHECK:  [[CALLEE_2:%.*]] ([[ARG_3:.*]]: s32[4], [[ARG_4:.*]]: s32[4]) -> s32[4] {
-// CHECK:  %[[ARG_3]] = s32[4] parameter(0)
-// CHECK:  %[[ARG_4]] = s32[4] parameter(1)
-// CHECK-LABEL:  ROOT
-// CHECK-SAME:  s32[4] add(s32[4] %[[ARG_3]], s32[4] %[[ARG_4]])
-
-// CHECK:  ENTRY [[MAIN:%.*]] ([[ARG:.*]]: s32[4]) -> s32[4] {
-// CHECK:  %[[ARG]] = s32[4] parameter(0)
-// CHECK:  [[CALL_OUT:%.*]] = s32[4] call(s32[4] %[[ARG]], s32[4] %[[ARG]]), to_apply=[[CALLEE_1]]
-// CHECK-LABEL:  ROOT
-// CHECK-SAME:  s32[4] call(s32[4] [[CALL_OUT]], s32[4] [[CALL_OUT]]), to_apply=[[CALLEE_2]]
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/call_multiple_results.mlir b/tensorflow/compiler/mlir/xla/tests/translate/call_multiple_results.mlir
deleted file mode 100644
index 3276cb71090..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/call_multiple_results.mlir
+++ /dev/null
@@ -1,24 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-func @main(%arg0: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) {
-  %0:2 = call @callee(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>)
-  return %0#0, %0#1 : tensor<4xi32>, tensor<4xi32>
-}
-
-func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) {
-  %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
-  %1 = "xla_hlo.mul"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
-  return %0, %1 : tensor<4xi32>, tensor<4xi32>
-}
-
-// Get name of callee computation
-// CHECK:  [[CALLEE:%.*]] ({{.*}}) -> ({{.*}}) {
-
-// CHECK-LABEL:  ENTRY
-// CHECK-SAME:  [[MAIN:%.*]] ([[ARG:.*]]: s32[4]) -> (s32[4], s32[4]) {
-// CHECK:  %[[ARG]] = s32[4] parameter(0)
-// CHECK:  [[CALL_OUT:%.*]] = (s32[4], s32[4]) call(s32[4] %[[ARG]], s32[4] %[[ARG]]), to_apply=[[CALLEE]]
-// CHECK:  [[OUT_0:%.*]] = s32[4] get-tuple-element((s32[4], s32[4]) [[CALL_OUT]]), index=0
-// CHECK:  [[OUT_1:%.*]] = s32[4] get-tuple-element((s32[4], s32[4]) [[CALL_OUT]]), index=1
-// CHECK-LABEL:  ROOT
-// CHECK-SAME:  (s32[4], s32[4]) tuple(s32[4] [[OUT_0]], s32[4] [[OUT_1]])
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/concatenate.mlir b/tensorflow/compiler/mlir/xla/tests/translate/concatenate.mlir
deleted file mode 100644
index 593c2e2f4e6..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/concatenate.mlir
+++ /dev/null
@@ -1,17 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-func @main(%arg0 : tensor<5x2xf32>,
-           %arg1 : tensor<5x5xf32>,
-           %arg2 : tensor<5x7xf32>) -> tensor<5x14xf32> {
-  %result = "xla_hlo.concatenate"(%arg0, %arg1, %arg2) {
-    dimension = 1 : i64
-  } : (tensor<5x2xf32>, tensor<5x5xf32>, tensor<5x7xf32>) -> tensor<5x14xf32>
-  return %result : tensor<5x14xf32>
-}
-
-
-// CHECK-LABEL: main
-// CHECK: %[[ARG0:.*]] = f32[5,2] parameter(0)
-// CHECK: %[[ARG1:.*]] = f32[5,5] parameter(1)
-// CHECK: %[[ARG2:.*]] = f32[5,7] parameter(2)
-// CHECK: ROOT %[[RESULT:.*]] = f32[5,14] concatenate(f32[5,2] %[[ARG0]], f32[5,5] %[[ARG1]], f32[5,7] %[[ARG2]]), dimensions={1}
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/conditional.mlir b/tensorflow/compiler/mlir/xla/tests/translate/conditional.mlir
index 1eb1e5ca7a5..e69d677a8cc 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/conditional.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/translate/conditional.mlir
@@ -42,13 +42,13 @@ func @main(%arg0: tensor) -> tuple> {
 
   // CHECK:   %[[VAL3:.+]] = (f32[]) conditional(pred[] %[[VAL1]], (f32[]) %[[VAL2]], (f32[]) %[[VAL2]]), true_computation=[[R0]], false_computation=[[R1]]
   %2 = "xla_hlo.conditional"(%0, %1, %1) ( {
-  ^bb0(%arg1: tuple>):	// no predecessors
+  ^bb0(%arg1: tuple>):
     %6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple>) -> tensor
     %7 = "xla_hlo.log"(%6) : (tensor) -> tensor
     %8 = "xla_hlo.tuple"(%7) : (tensor) -> tuple>
     "xla_hlo.return"(%8) : (tuple>) -> ()
   },  {
-  ^bb0(%arg1: tuple>):	// no predecessors
+  ^bb0(%arg1: tuple>):
     %6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple>) -> tensor
     %7 = "xla_hlo.exp"(%6) : (tensor) -> tensor
     %8 = "xla_hlo.tuple"(%7) : (tensor) -> tuple>
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/const.mlir b/tensorflow/compiler/mlir/xla/tests/translate/const.mlir
deleted file mode 100644
index 42d9c5dc963..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/const.mlir
+++ /dev/null
@@ -1,30 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s --dump-input-on-failure
-
-// CHECK-LABEL: ENTRY %main
-func @main() -> tensor<2x2x1x1xf32> {
-  // CHECK: constant.{{.*}} = s64[] constant(1)
-  %cst = constant dense<1> : tensor
-  // CHECK: constant.{{.*}} = f32[2,2,1,1]
-  // CHECK-SAME: { { /*i0=0*/ { /*i1=0*/ {1} }, { /*i1=1*/ {2} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {4} } } }
-  %cst_0 = constant dense<
-    [[[[1.000000e+00]], [[2.000000e+00]]], [[[3.000000e+00]], [[4.000000e+00]]]]
-  > : tensor<2x2x1x1xf32>
-
-  // CHECK: s32[1] constant({1})
-  %cst_1 = constant dense<1> : tensor<1xi32>
-
-  // CHECK: %[[C:.*]] = s32[] constant(1)
-  // CHECK: s32[10] broadcast(s32[] %[[C]])
-  %cst_2 = constant dense<1> : tensor<10xi32>
-
-  // CHECK: s32[4] constant({1, 2, 3, 4})
-  %cst_3 = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
-
-  // CHECK: s32[2,2] constant({ { 1, 2 }, { 3, 4 } })
-  %cst_4 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
-
-  // CHECK: s32[2,2] constant({ { 3, 2 }, { 1, 4 } })
-  %cst_5 = constant dense<[[3, 2], [1, 4]]> : tensor<2x2xi32>
-
-  return %cst_0 : tensor<2x2x1x1xf32>
-}
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/conv.mlir b/tensorflow/compiler/mlir/xla/tests/translate/conv.mlir
deleted file mode 100644
index 5cdc65b49af..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/conv.mlir
+++ /dev/null
@@ -1,31 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-func @main(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> {
-  %result = "xla_hlo.conv"(%arg0, %arg1) {
-    batch_group_count = 1 : i64,
-    dimension_numbers = {
-      input_batch_dimension = 0 : i64,
-      input_feature_dimension = 3 : i64,
-      input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
-      kernel_input_feature_dimension = 3 : i64,
-      kernel_output_feature_dimension = 2 : i64,
-      kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
-      output_batch_dimension = 0 : i64,
-      output_feature_dimension = 3 : i64,
-      output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
-    },
-    feature_group_count = 1 : i64,
-    lhs_dilation = dense<1> : tensor<2xi64>,
-    padding = dense<2> : tensor<2x2xi64>,
-    rhs_dilation = dense<1> : tensor<2xi64>,
-    window_strides = dense<1> : tensor<2xi64>
-  } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32>
-  return %result : tensor<100x28x28x1xf32>
-}
-
-// CHECK-LABEL: main
-// CHECK: %[[ARG0:.*]] = f32[100,26,26,32] parameter(0)
-// CHECK: %[[ARG1:.*]] = f32[3,3,1,32] parameter(1)
-// CHECK: ROOT %[[RESULT:.*]] = f32[100,28,28,1] convolution(f32[100,26,26,32] %[[ARG0]], f32[3,3,1,32] %[[ARG1]]),
-// CHECK-SAME: window={size=3x3 pad=2_2x2_2},
-// CHECK-SAME: dim_labels=b01f_01oi->b01f
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/convert.mlir b/tensorflow/compiler/mlir/xla/tests/translate/convert.mlir
deleted file mode 100644
index dd839df38b2..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/convert.mlir
+++ /dev/null
@@ -1,10 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> {
-  %0 = "xla_hlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
-  return %0 : tensor<2xf32>
-}
-
-// CHECK: ENTRY %main
-// CHECK: %[[ARG:.*]] = s32[2] parameter(0)
-// CHECK: ROOT %[[RESULT:.*]] = f32[2] convert(s32[2] %[[ARG]])
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/copy.mlir b/tensorflow/compiler/mlir/xla/tests/translate/copy.mlir
deleted file mode 100644
index f6e5ef8fd98..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/copy.mlir
+++ /dev/null
@@ -1,10 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> {
-  %0 = "xla_hlo.copy"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
-  return %0 : tensor<2xi32>
-}
-
-// CHECK: ENTRY %main
-// CHECK: [[ARG:%.*]] = s32[2] parameter(0)
-// CHECK: ROOT %[[RESULT:.*]] = s32[2] copy(s32[2] [[ARG]])
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/cross_replica_sum.mlir b/tensorflow/compiler/mlir/xla/tests/translate/cross_replica_sum.mlir
deleted file mode 100644
index 2e094c76516..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/cross_replica_sum.mlir
+++ /dev/null
@@ -1,16 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
-  %0 = xla_hlo.constant dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi32>
-  %1 = "xla_hlo.cross-replica-sum"(%arg0) {replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>} : (tensor<10xf32>) -> tensor<10xf32>
-  return %1 : tensor<10xf32>
-}
-
-// CHECK: %[[SUM_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> f32[]
-// CHECK:  ROOT %[[RESULT:.*]] = f32[] add(f32[] %[[ARG0]], f32[] %[[ARG1]])
-
-// CHECK: ENTRY %main
-// CHECK: %[[ARG0:.*]] = f32[10] parameter(0)
-// CHECK: ROOT %[[RESULT:.*]] = f32[10] all-reduce(f32[10] %[[ARG0]])
-// CHECK-SAME: replica_groups={{[{][{]}}0,2,4,6},{1,3,5,7{{[}][}]}}
-// CHECK-SAME: to_apply=%[[SUM_COMPUTATION]]
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/einsum.mlir b/tensorflow/compiler/mlir/xla/tests/translate/einsum.mlir
deleted file mode 100644
index e703a5cb872..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/einsum.mlir
+++ /dev/null
@@ -1,9 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-// CHECK-LABEL: ENTRY
-func @main(%arg0: tensor<3x4xi32>, %arg1: tensor<4x5xi32>) -> tensor<3x5xi32> {
-  // Simple einsum is lowered to HLO dot op.
-  // CHECK: dot(s32[3,4] %{{.*}}, s32[4,5] %{{.*}}), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  %0 = "xla_hlo.einsum"(%arg0, %arg1) {einsum_config = "ab,bc->ac"} : (tensor<3x4xi32>, tensor<4x5xi32>) -> tensor<3x5xi32>
-  return %0 : tensor<3x5xi32>
-}
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
new file mode 100644
index 00000000000..70b48fa43c9
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
@@ -0,0 +1,622 @@
+// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s
+
+// CHECK-LABEL:  HloModule
+func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
+  %0 = "xla_hlo.all_reduce"(%arg0) ({
+  // Perform max reduction inside the region
+  ^bb0(%lhs: tensor, %rhs: tensor):
+    %max = xla_hlo.max %lhs, %rhs : tensor
+    "xla_hlo.return"(%max) : (tensor) -> ()
+  })
+  {
+    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
+    channel_id = {
+      handle = 5 : i64,
+      type = 2 : i64
+    }
+  } : (tensor<10xf32>) -> tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
+
+// CHECK:  %[[COMPUTATION:.*]] ({{.*}}: f32[], {{.*}}: f32[]) -> f32[]
+// CHECK-LABEL:  ENTRY
+// CHECK:  %[[ARG0:.*]] = f32[10] parameter(0)
+// CHECK:  ROOT %[[RESULT:.*]] = f32[10] all-reduce(f32[10] %[[ARG0]])
+// CHECK-SAME:  channel_id=5
+// CHECK-SAME:  replica_groups={{[{][{]}}0,2,4,6},{1,3,5,7{{[}][}]}}
+// CHECK-SAME:  to_apply=%[[COMPUTATION]]
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tuple, tensor<2xf32>, tensor<2xf32>> {
+  %0 = "xla_hlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> tuple, tensor<2xf32>, tensor<2xf32>>
+  return %0 : tuple, tensor<2xf32>, tensor<2xf32>>
+}
+
+// CHECK-LABEL:  ENTRY
+// CHECK:  [[VAL_1:%.*]] = f32[2,2,2,2] parameter(0)
+// CHECK:  [[VAL_2:%.*]] = f32[2] parameter(1)
+// CHECK:  [[VAL_3:%.*]] = f32[2] parameter(2)
+// CHECK:  [[VAL_4:%.*]] = f32[2] parameter(3)
+// CHECK:  [[VAL_5:%.*]] = f32[2,2,2,2] parameter(4)
+// CHECK-LABEL:  ROOT
+// CHECK-SAME:  (f32[2,2,2,2], f32[2], f32[2]) batch-norm-grad(f32[2,2,2,2] [[VAL_1]], f32[2] [[VAL_2]], f32[2] [[VAL_3]], f32[2] [[VAL_4]], f32[2,2,2,2] [[VAL_5]]), epsilon=0.001, feature_index=0
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %offset: tensor<2xf32>) -> tuple, tensor<2xf32>, tensor<2xf32>> {
+  %0 = "xla_hlo.batch_norm_training" (%input, %scale, %offset) {epsilon = 0.001 : f32, feature_index = 3 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tuple, tensor<2xf32>, tensor<2xf32>>
+  return %0 : tuple, tensor<2xf32>, tensor<2xf32>>
+}
+
+// CHECK-LABEL:  ENTRY
+// CHECK:  [[VAL_1:%.*]] = f32[2,2,2,2] parameter(0)
+// CHECK:  [[VAL_2:%.*]] = f32[2] parameter(1)
+// CHECK:  [[VAL_3:%.*]] = f32[2] parameter(2)
+// CHECK-LABEL:  ROOT
+// CHECK-SAME:  (f32[2,2,2,2], f32[2], f32[2]) batch-norm-training(f32[2,2,2,2] [[VAL_1]], f32[2] [[VAL_2]], f32[2] [[VAL_3]]), epsilon=0.001, feature_index=3
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
+  // CHECK:  [[VAL_1:%.*]] = s32[4] parameter(0)
+  // CHECK:  [[VAL_2:%.*]] = s32[4] parameter(1)
+  // CHECK:  [[ATAN2:%.*]] = s32[4] atan2(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
+  %0 = xla_hlo.atan2 %arg0, %arg1 : tensor<4xi32>
+
+  // CHECK:  [[SHL:%.*]] = s32[4] shift-left(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
+  %1 = xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32>
+
+  // CHECK:  [[SHRA:%.*]] = s32[4] shift-right-arithmetic(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
+  %2 = xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>
+
+  // CHECK:  [[SHRL:%.*]] = s32[4] shift-right-logical(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
+  %3 = xla_hlo.shift_right_logical %arg0, %arg1 : tensor<4xi32>
+
+  // CHECK-LABEL:  ROOT
+  // CHECK-SAME:  [[VAL_7:%.*]] = (s32[4], s32[4], s32[4], s32[4]) tuple(s32[4] [[ATAN2]], s32[4] [[SHL]], s32[4] [[SHRA]], s32[4] [[SHRL]])
+  return %0, %1, %2, %3 : tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main(%arg0: tensor<1x4xi32>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3x4xi32>) -> tensor<2x3x4xi32> {
+  // Same rank degenerate broadcast
+  // CHECK:  [[ARG_0:%.*]] = s32[1,4] parameter(0)
+  // CHECK-NEXT:  [[RESHAPE_1:%.*]] = s32[4] reshape(s32[1,4] [[ARG_0]])
+  // CHECK-NEXT:  [[BROADCAST_1:%.*]] = s32[2,4] broadcast(s32[4] [[RESHAPE_1]])
+  // CHECK-NEXT:  [[ARG_1:%.*]] = s32[2,4] parameter(1)
+  // CHECK-NEXT:  s32[2,4] add(s32[2,4] [[BROADCAST_1]], s32[2,4] [[ARG_1]])
+  %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<1x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
+
+  // Broadcast up rank
+  // CHECK-NEXT:  [[BROADCAST_2:%.*]] = s32[2,3,4] broadcast(s32[2,4] [[ARG_1]]), dimensions={0,2}
+  // CHECK-NEXT:  [[ARG_2:%.*]] = s32[2,3,4] parameter(2)
+  // CHECK-NEXT:  s32[2,3,4] add(s32[2,3,4] [[BROADCAST_2]], s32[2,3,4] [[ARG_2]])
+  %1 = "xla_hlo.add"(%arg1, %arg2) {broadcast_dimensions = dense<[0,2]> : tensor<2xi64>} : (tensor<2x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>
+
+  // Broadcast up rank + degenerate broadcast
+  // CHECK-NEXT:  [[BROADCAST_3:%.*]] = s32[2,1,4] broadcast(s32[1,4] [[ARG_0]]), dimensions={1,2}
+  // CHECK-NEXT:  [[RESHAPE_2:%.*]] = s32[2,4] reshape(s32[2,1,4] [[BROADCAST_3]])
+  // CHECK-NEXT:  [[BROADCAST_4:%.*]] = s32[2,3,4] broadcast(s32[2,4] [[RESHAPE_2]]), dimensions={0,2}
+  // CHECK-LABEL:  ROOT
+  // CHECK-SAME:  s32[2,3,4] add(s32[2,3,4] [[BROADCAST_4]], s32[2,3,4] [[ARG_2]])
+  %2 = "xla_hlo.add"(%arg0, %arg2) {broadcast_dimensions = dense<[1,2]> : tensor<2xi64>} : (tensor<1x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>
+  return %2 : tensor<2x3x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main(%arg0: tensor<4xi32>) -> tensor<1x2x3x4xi32> {
+  // CHECK:  [[ARG:%.*]] = s32[4] parameter(0)
+  // CHECK-NEXT:  ROOT %broadcast.2 = s32[1,2,3,4] broadcast(s32[4] [[ARG]]), dimensions={3}
+  %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32>
+  return %0 : tensor<1x2x3x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main(%arg0: tensor<1xf32>) -> tensor<1x10xf32> {
+  %result = "xla_hlo.broadcast_in_dim"(%arg0) {
+    broadcast_dimensions = dense<0> : tensor<1xi64>
+  } : (tensor<1xf32>) -> tensor<1x10xf32>
+  return %result : tensor<1x10xf32>
+}
+
+// CHECK-LABEL:  ENTRY
+// CHECK:  [[ARG:%.*]] = f32[1] parameter(0)
+// CHECK:  ROOT %broadcast.2 = f32[1,10] broadcast(f32[1] [[ARG]]), dimensions={0}
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main(%arg0: tensor<4xi32>) -> tensor<4xi32> {
+  %0 = call @callee(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  %1 = call @callee(%0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  return %1 : tensor<4xi32>
+}
+
+func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
+  %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  return %0 : tensor<4xi32>
+}
+
+// CHECK:  [[CALLEE_1:%.*]] ([[ARG_1:.*]]: s32[4], [[ARG_2:.*]]: s32[4]) -> s32[4] {
+// CHECK:  %[[ARG_1]] = s32[4] parameter(0)
+// CHECK:  %[[ARG_2]] = s32[4] parameter(1)
+// CHECK-LABEL:  ROOT
+// CHECK-SAME:  s32[4] add(s32[4] %[[ARG_1]], s32[4] %[[ARG_2]])
+
+// CHECK:  [[CALLEE_2:%.*]] ([[ARG_3:.*]]: s32[4], [[ARG_4:.*]]: s32[4]) -> s32[4] {
+// CHECK:  %[[ARG_3]] = s32[4] parameter(0)
+// CHECK:  %[[ARG_4]] = s32[4] parameter(1)
+// CHECK-LABEL:  ROOT
+// CHECK-SAME:  s32[4] add(s32[4] %[[ARG_3]], s32[4] %[[ARG_4]])
+
+// CHECK:  ENTRY [[MAIN:%.*]] ([[ARG:.*]]: s32[4]) -> s32[4] {
+// CHECK:  %[[ARG]] = s32[4] parameter(0)
+// CHECK:  [[CALL_OUT:%.*]] = s32[4] call(s32[4] %[[ARG]], s32[4] %[[ARG]]), to_apply=[[CALLEE_1]]
+// CHECK-LABEL:  ROOT
+// CHECK-SAME:  s32[4] call(s32[4] [[CALL_OUT]], s32[4] [[CALL_OUT]]), to_apply=[[CALLEE_2]]
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main(%arg0: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) {
+  %0:2 = call @callee(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>)
+  return %0#0, %0#1 : tensor<4xi32>, tensor<4xi32>
+}
+
+func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) {
+  %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  %1 = "xla_hlo.mul"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  return %0, %1 : tensor<4xi32>, tensor<4xi32>
+}
+
+// Get name of callee computation
+// CHECK:  [[CALLEE:%.*]] ({{.*}}) -> ({{.*}}) {
+
+// CHECK-LABEL:  ENTRY
+// CHECK-SAME:  [[MAIN:%.*]] ([[ARG:.*]]: s32[4]) -> (s32[4], s32[4]) {
+// CHECK:  %[[ARG]] = s32[4] parameter(0)
+// CHECK:  [[CALL_OUT:%.*]] = (s32[4], s32[4]) call(s32[4] %[[ARG]], s32[4] %[[ARG]]), to_apply=[[CALLEE]]
+// CHECK:  [[OUT_0:%.*]] = s32[4] get-tuple-element((s32[4], s32[4]) [[CALL_OUT]]), index=0
+// CHECK:  [[OUT_1:%.*]] = s32[4] get-tuple-element((s32[4], s32[4]) [[CALL_OUT]]), index=1
+// CHECK-LABEL:  ROOT
+// CHECK-SAME:  (s32[4], s32[4]) tuple(s32[4] [[OUT_0]], s32[4] [[OUT_1]])
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main(%arg0 : tensor<5x2xf32>,
+           %arg1 : tensor<5x5xf32>,
+           %arg2 : tensor<5x7xf32>) -> tensor<5x14xf32> {
+  %result = "xla_hlo.concatenate"(%arg0, %arg1, %arg2) {
+    dimension = 1 : i64
+  } : (tensor<5x2xf32>, tensor<5x5xf32>, tensor<5x7xf32>) -> tensor<5x14xf32>
+  return %result : tensor<5x14xf32>
+}
+
+// CHECK-LABEL:  ENTRY
+// CHECK:  %[[ARG0:.*]] = f32[5,2] parameter(0)
+// CHECK:  %[[ARG1:.*]] = f32[5,5] parameter(1)
+// CHECK:  %[[ARG2:.*]] = f32[5,7] parameter(2)
+// CHECK:  ROOT %[[RESULT:.*]] = f32[5,14] concatenate(f32[5,2] %[[ARG0]], f32[5,5] %[[ARG1]], f32[5,7] %[[ARG2]]), dimensions={1}
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main() -> tensor<2x2x1x1xf32> {
+  // CHECK:  constant.{{.*}} = s64[] constant(1)
+  %cst = constant dense<1> : tensor
+  // CHECK:  constant.{{.*}} = f32[2,2,1,1]
+  // CHECK-SAME:  { { /*i0=0*/ { /*i1=0*/ {1} }, { /*i1=1*/ {2} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {4} } } }
+  %cst_0 = constant dense<
+    [[[[1.000000e+00]], [[2.000000e+00]]], [[[3.000000e+00]], [[4.000000e+00]]]]
+  > : tensor<2x2x1x1xf32>
+
+  // CHECK:  s32[1] constant({1})
+  %cst_1 = constant dense<1> : tensor<1xi32>
+
+  // CHECK:  %[[C:.*]] = s32[] constant(1)
+  // CHECK:  s32[10] broadcast(s32[] %[[C]])
+  %cst_2 = constant dense<1> : tensor<10xi32>
+
+  // CHECK:  s32[4] constant({1, 2, 3, 4})
+  %cst_3 = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+
+  // CHECK:  s32[2,2] constant({ { 1, 2 }, { 3, 4 } })
+  %cst_4 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+
+  // CHECK:  s32[2,2] constant({ { 3, 2 }, { 1, 4 } })
+  %cst_5 = constant dense<[[3, 2], [1, 4]]> : tensor<2x2xi32>
+
+  return %cst_0 : tensor<2x2x1x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> {
+  %result = "xla_hlo.conv"(%arg0, %arg1) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = {
+      input_batch_dimension = 0 : i64,
+      input_feature_dimension = 3 : i64,
+      input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+      kernel_input_feature_dimension = 3 : i64,
+      kernel_output_feature_dimension = 2 : i64,
+      kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+      output_batch_dimension = 0 : i64,
+      output_feature_dimension = 3 : i64,
+      output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
+    },
+    feature_group_count = 1 : i64,
+    lhs_dilation = dense<1> : tensor<2xi64>,
+    padding = dense<2> : tensor<2x2xi64>,
+    rhs_dilation = dense<1> : tensor<2xi64>,
+    window_strides = dense<1> : tensor<2xi64>
+  } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32>
+  return %result : tensor<100x28x28x1xf32>
+}
+
+// CHECK-LABEL:  ENTRY
+// CHECK:  %[[ARG0:.*]] = f32[100,26,26,32] parameter(0)
+// CHECK:  %[[ARG1:.*]] = f32[3,3,1,32] parameter(1)
+// CHECK:  ROOT %[[RESULT:.*]] = f32[100,28,28,1] convolution(f32[100,26,26,32] %[[ARG0]], f32[3,3,1,32] %[[ARG1]]),
+// CHECK-SAME:  window={size=3x3 pad=2_2x2_2},
+// CHECK-SAME:  dim_labels=b01f_01oi->b01f
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> {
+  %0 = "xla_hlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// CHECK-LABEL:  ENTRY
+// CHECK:  %[[ARG:.*]] = s32[2] parameter(0)
+// CHECK:  ROOT %[[RESULT:.*]] = f32[2] convert(s32[2] %[[ARG]])
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> {
+  %0 = "xla_hlo.copy"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
+  return %0 : tensor<2xi32>
+}
+
+// CHECK-LABEL:  ENTRY
+// CHECK:  [[ARG:%.*]] = s32[2] parameter(0)
+// CHECK:  ROOT %[[RESULT:.*]] = s32[2] copy(s32[2] [[ARG]])
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
+  %0 = xla_hlo.constant dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi32>
+  %1 = "xla_hlo.cross-replica-sum"(%arg0) {replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>} : (tensor<10xf32>) -> tensor<10xf32>
+  return %1 : tensor<10xf32>
+}
+
+// CHECK:  %[[SUM_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> f32[]
+// CHECK:  ROOT %[[RESULT:.*]] = f32[] add(f32[] %[[ARG0]], f32[] %[[ARG1]])
+
+// CHECK-LABEL:  ENTRY
+// CHECK:  %[[ARG0:.*]] = f32[10] parameter(0)
+// CHECK:  ROOT %[[RESULT:.*]] = f32[10] all-reduce(f32[10] %[[ARG0]])
+// CHECK-SAME:  replica_groups={{[{][{]}}0,2,4,6},{1,3,5,7{{[}][}]}}
+// CHECK-SAME:  to_apply=%[[SUM_COMPUTATION]]
+
+// -----
+
+// CHECK-LABEL: HloModule
+func @main(%arg0: tensor<3x4xi32>, %arg1: tensor<4x5xi32>) -> tensor<3x5xi32> {
+  // Simple einsum is lowered to HLO dot op.
+  // CHECK: dot(s32[3,4] %{{.*}}, s32[4,5] %{{.*}}), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  %0 = "xla_hlo.einsum"(%arg0, %arg1) {einsum_config = "ab,bc->ac"} : (tensor<3x4xi32>, tensor<4x5xi32>) -> tensor<3x5xi32>
+  return %0 : tensor<3x5xi32>
+}
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main(%arg: tensor<4x2xf32>) -> tensor {
+  %0 = "xla_hlo.get_dimension_size"(%arg) {dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor
+  return %0 : tensor
+}
+
+// CHECK-LABEL:  ENTRY
+// CHECK:  [[ARG:%.*]] = f32[4,2] parameter(0)
+// CHECK:  s32[] get-dimension-size(f32[4,2] [[ARG]]), dimensions={1}
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main(%arg0: tuple, tensor>) -> tensor {
+  %0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, tensor>) -> tensor
+  return %0 : tensor
+}
+
+// CHECK-LABEL:  ENTRY
+// CHECK:  %[[ARG0:.*]] = (f32[], s32[]) parameter(0)
+// CHECK:  ROOT %[[RESULT:.*]] = f32[] get-tuple-element((f32[], s32[]) %[[ARG0]]), index=0
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main() -> tensor<1x10xf32> {
+  %result = "xla_hlo.iota"() {
+    iota_dimension = 1 : i64
+  } : () -> tensor<1x10xf32>
+  return %result : tensor<1x10xf32>
+}
+
+// CHECK-LABEL:  ENTRY
+// CHECK:  ROOT %[[RESULT:.*]] = f32[1,10] iota(), iota_dimension=1
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main(%arg: tensor<4x6xf32>, %pad: tensor) -> tensor<13x19xf32> {
+  %0 = "xla_hlo.pad"(%arg, %pad) {edge_padding_high = dense<[4,5]> : tensor<2xi64>, edge_padding_low = dense<[2,3]> : tensor<2xi64>, interior_padding = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>, tensor) -> tensor<13x19xf32>
+  return %0 : tensor<13x19xf32>
+}
+
+// CHECK-LABEL:  ENTRY
+// CHECK:  [[ARG:%.*]] = f32[4,6] parameter(0)
+// CHECK:  [[PADDING_VAL:%.*]] = f32[] parameter(1)
+// CHECK-LABEL:  ROOT
+// CHECK-SAME:  f32[13,19] pad(f32[4,6] [[ARG]], f32[] [[PADDING_VAL]]), padding=2_4_1x3_5_1
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1x10xi32>, %arg2 : tensor, %arg3 : tensor) -> (tensor<1xf32>, tensor<1xi32>) {
+  %result0, %result1 = "xla_hlo.reduce"(%arg0, %arg1, %arg2, %arg3) ( {
+    ^bb0(%fa: tensor, %ia : tensor, %fb: tensor, %ib: tensor):   // no predecessors
+      %fmax = "xla_hlo.max"(%fa, %fb) {} : (tensor, tensor) -> tensor
+      %imax = "xla_hlo.max"(%ia, %ib) {} : (tensor, tensor) -> tensor
+      "xla_hlo.return"(%fmax, %imax) : (tensor, tensor) -> ()
+    }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<1x10xi32>, tensor, tensor) -> (tensor<1xf32>, tensor<1xi32>)
+  return %result0, %result1 : tensor<1xf32>, tensor<1xi32>
+}
+
+// CHECK:  %[[REGION:region_[0-9]+]]
+// CHECK-SAME:  ([[ARG_FA:.*]]: f32[], [[ARG_IA:.*]]: s32[], [[ARG_FB:.*]]: f32[], [[ARG_IB:.*]]: s32[]) -> (f32[], s32[])
+// CHECK:  %[[FMAX:.*]] = f32[] maximum(f32[] %[[ARG_FA]], f32[] %[[ARG_FB]])
+// CHECK:  %[[IMAX:.*]] = s32[] maximum(s32[] %[[ARG_IA]], s32[] %[[ARG_IB]])
+// CHECK:  ROOT %[[RESULT_REGION:.*]] = (f32[], s32[]) tuple(f32[] %[[FMAX]], s32[] %[[IMAX]])
+
+// CHECK-LABEL:  ENTRY
+// CHECK-SAME:  ([[ARG0:.*]]: f32[1,10], [[ARG1:.*]]: s32[1,10], [[ARG2:.*]]: f32[], [[ARG3:.*]]: s32[]) -> (f32[1], s32[1])
+// CHECK:  %[[RESULT:.*]] = (f32[1], s32[1]) reduce(f32[1,10] %[[ARG0]], s32[1,10] %[[ARG1]], f32[] %[[ARG2]], s32[] %[[ARG3]]), dimensions={1}, to_apply=%[[REGION]]
+// CHECK:  %[[RESULT0:.*]] = f32[1] get-tuple-element((f32[1], s32[1]) %[[RESULT]]), index=0
+// CHECK:  %[[RESULT1:.*]] = s32[1] get-tuple-element((f32[1], s32[1]) %[[RESULT]]), index=1
+// CHECK:  ROOT %[[RESULT:.*]] = (f32[1], s32[1]) tuple(f32[1] %[[RESULT0]], s32[1] %[[RESULT1]])
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main(%arg0: tensor<2x17x31x7xi32>) -> tensor<2x3x5x7xi32> {
+  %0 = xla_hlo.constant dense<-2147483648> : tensor
+  %1 = "xla_hlo.reduce_window"(%arg0, %0) ( {
+  ^bb0(%arg1: tensor, %arg2: tensor):	// no predecessors
+    %2 = xla_hlo.max %arg1, %arg2 : tensor
+    "xla_hlo.return"(%2) : (tensor) -> ()
+  }) {
+    window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
+    window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>,
+    padding = dense<[[0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64>,
+    base_dilations = dense<[1, 1, 1, 1]> : tensor<4xi64>,
+    window_dilations = dense<[1, 2, 2, 1]> : tensor<4xi64>
+  } : (tensor<2x17x31x7xi32>, tensor) -> tensor<2x3x5x7xi32>
+  return %1 : tensor<2x3x5x7xi32>
+}
+
+// CHECK:  %[[MAX_COMPUTATION:.*]] ([[ARG0:.*]]: s32[], [[ARG1:.*]]: s32[]) -> s32[]
+// CHECK:  ROOT %[[RESULT:.*]] = s32[] maximum(s32[] %[[ARG0]], s32[] %[[ARG1]])
+
+// CHECK-LABEL:  ENTRY
+// CHECK-DAG:  %[[ARG0:.*]] = s32[2,17,31,7] parameter(0)
+// CHECK-DAG:  %[[INIT:.*]] = s32[] constant(-2147483648)
+// CHECK:  ROOT %[[RESULT:.*]] = s32[2,5,8,7] reduce-window(s32[2,17,31,7] %[[ARG0]], s32[] %constant.2),
+// CHECK-SAME:  window={size=1x2x2x1 stride=1x4x4x1 pad=0_0x2_0x0_2x0_0 rhs_dilate=1x2x2x1},
+// CHECK-SAME:  to_apply=%[[MAX_COMPUTATION]]
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main(%arg0: tensor<2xf32>) -> tensor<1x2xf32> {
+  %0 = "xla_hlo.reshape"(%arg0) : (tensor<2xf32>) -> tensor<1x2xf32>
+  return %0 : tensor<1x2xf32>
+}
+
+// CHECK-LABEL:  ENTRY
+// CHECK:  %[[ARG0:.*]] = f32[2] parameter(0)
+// CHECK:  ROOT %[[RESULT:.*]] = f32[1,2] reshape(f32[2] %[[ARG0]])
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main(%arg0 : tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> {
+  %result = "xla_hlo.reverse"(%arg0) {
+    dimensions = dense<[1,2]> : tensor<2xi64>
+  } : (tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32>
+  return %result : tensor<10x11x12x13xf32>
+}
+
+// CHECK-LABEL:  ENTRY
+// CHECK:  %[[ARG0:.*]] = f32[10,11,12,13] parameter(0)
+// CHECK:  ROOT %[[RESULT:.*]] = f32[10,11,12,13] reverse(f32[10,11,12,13] %[[ARG0]]), dimensions={1,2}
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main() -> tensor<2x3x5xf32> {
+  %0 = xla_hlo.constant dense<0.000000e+00> : tensor
+  %1 = xla_hlo.constant dense<1.000000e+00> : tensor
+  %2 = xla_hlo.constant dense<[2, 3, 5]> : tensor<3xi64>
+  %3 = "xla_hlo.rng_uniform"(%0, %1, %2) : (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32>
+  return %3 : tensor<2x3x5xf32>
+}
+
+// CHECK-LABEL:  ENTRY
+// CHECK-DAG:  %[[A:.*]] = f32[] constant(0)
+// CHECK-DAG:  %[[B:.*]] = f32[] constant(1)
+// CHECK:  ROOT %[[RESULT:.*]] = f32[2,3,5] rng(f32[] %[[A]], f32[] %[[B]]), distribution=rng_uniform
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main(%input_tensor: tensor<200x100x300xf32>, %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> tensor<200x100x300xf32> {
+  %0 = "xla_hlo.scatter" (%input_tensor, %scatter_indices, %updates) ({
+  ^bb0(%lhs: tensor, %rhs: tensor): // no predecessors
+    %add = xla_hlo.add %lhs, %rhs : tensor
+    "xla_hlo.return"(%add) : (tensor) -> ()
+  }) {
+    scatter_dimension_numbers = {
+      update_window_dims = dense<[1]> : tensor<1xi64>,
+      inserted_window_dims = dense<[0, 1]> : tensor<2xi64>,
+      scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>,
+      index_vector_dim = 1 : i64
+    },
+    indices_are_sorted = true,
+    unique_indices = true
+  } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<200x100x300xf32>
+  return %0 : tensor<200x100x300xf32>
+}
+
+// CHECK:  [[COMPUTATION:%.*]] ({{.*}}: f32[], {{.*}}: f32[]) -> f32[]
+// CHECK-LABEL:  ENTRY
+// CHECK:  [[VAL_1:%.*]] = f32[200,100,300] parameter(0)
+// CHECK:  [[VAL_2:%.*]] = s32[10,2] parameter(1)
+// CHECK:  [[VAL_3:%.*]] = f32[10,300] parameter(2)
+// CHECK-LABEL:  ROOT
+// CHECK-SAME:  f32[200,100,300] scatter(f32[200,100,300] [[VAL_1]], s32[10,2] [[VAL_2]], f32[10,300] [[VAL_3]]), update_window_dims={1}, inserted_window_dims={0,1}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, indices_are_sorted=true, unique_indices=true, to_apply=[[COMPUTATION]]
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main(%arg0: tensor, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  // CHECK:  %[[ARG0:.*]] = pred[] parameter(0)
+  // CHECK:  %[[COND:.*]] = pred[2,3] broadcast(pred[] %[[ARG0]]), dimensions={}
+  // CHECK:  %[[ARG1:.*]] = s32[2,3] parameter(1)
+  // CHECK:  %[[ARG2:.*]] = s32[2,3] parameter(2)
+
+  // CHECK:  ROOT %[[RES:.*]] = s32[2,3] select(pred[2,3] %[[COND]], s32[2,3] %[[ARG1]], s32[2,3] %[[ARG2]])
+  %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  return %0 : tensor<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> {
+  %0 = xla_hlo.constant dense<0.000000e+00> : tensor
+  %1 = "xla_hlo.select_and_scatter"(%arg0, %arg1, %0) ( {
+  ^bb0(%arg3: tensor, %arg4: tensor):	// no predecessors
+    %2 = "xla_hlo.compare"(%arg3, %arg4) {comparison_direction = "GE"} : (tensor, tensor) -> tensor
+    "xla_hlo.return"(%2) : (tensor) -> ()
+  },  {
+  ^bb0(%arg3: tensor, %arg4: tensor):	// no predecessors
+    %2 = xla_hlo.add %arg3, %arg4 : tensor
+    "xla_hlo.return"(%2) : (tensor) -> ()
+  }) {
+    window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
+    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>
+  } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> tensor<10x24x24x64xf32>
+  return %1 : tensor<10x24x24x64xf32>
+}
+
+// CHECK:  %[[SELECT_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> pred[] {
+// CHECK:  ROOT %[[RESULT:.*]] = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GE
+
+// CHECK:  %[[SCATTER_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> f32[] {
+// CHECK:  ROOT %[[RESULT:.*]] = f32[] add(f32[] %[[ARG0]], f32[] %[[ARG1]])
+
+// CHECK-LABEL:  ENTRY
+// CHECK:  %[[ARG0:.*]] = f32[10,24,24,64] parameter(0)
+// CHECK:  %[[ARG1:.*]] = f32[10,12,12,64] parameter(1)
+// CHECK:  %[[INIT:.*]] = f32[] constant(0)
+
+// CHECK:  ROOT %[[RESULT:.*]] = f32[10,24,24,64]
+// CHECK-SAME:  select-and-scatter(f32[10,24,24,64] %[[ARG0]], f32[10,12,12,64] %[[ARG1]], f32[] %[[INIT]]),
+// CHECK-SAME:  window={size=1x2x2x1 stride=1x2x2x1},
+// CHECK-SAME:  select=%[[SELECT_COMPUTATION]], scatter=%[[SCATTER_COMPUTATION]]
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main(%arg: tensor<3x4xi32>) -> tensor<1x2xi32> {
+  %0 = "xla_hlo.slice"(%arg) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32>
+  return %0 : tensor<1x2xi32>
+}
+
+// CHECK-LABEL:  ENTRY
+// CHECK:  [[ARG:%.*]] = s32[3,4] parameter(0)
+// CHECK-LABEL:  ROOT
+// CHECK-SAME:  s32[1,2] slice(s32[3,4] [[ARG]]), slice={[1:2:1], [0:4:2]}
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {
+  // CHECK:  [[ARG:%.*]] = s32[1,2,3,4] parameter(0)
+
+  // CHECK-NEXT:  ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] [[ARG]]), dimensions={1,0,3,2}
+  %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
+  return %0 : tensor<2x1x4x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main(%arg0: tensor, %arg1 : tensor) -> tuple, tensor> {
+  %result = "xla_hlo.tuple"(%arg0, %arg1) {} : (tensor, tensor) -> tuple, tensor>
+  return %result : tuple, tensor>
+}
+
+// CHECK-LABEL:  ENTRY
+// CHECK:  %[[ARG0:.*]] = f32[] parameter(0)
+// CHECK:  %[[ARG1:.*]] = s32[] parameter(1)
+// CHECK:  ROOT %[[RESULT:.*]] = (f32[], s32[]) tuple(f32[] %[[ARG0]], s32[] %[[ARG1]])
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main(%arg_f32: tensor<4xf32>, %arg_i32: tensor<4xi32>) -> (tensor<4xf32>, tensor<4xf32>, tensor<4xi32>, tensor<4xi32>) {
+  // CHECK:  [[ARG_F32:%.*]] = f32[4] parameter(0)
+  // CHECK:  [[EXPM1:%.*]] = f32[4] exponential-minus-one(f32[4] [[ARG_F32]])
+  %expm1 = "xla_hlo.exponential_minus_one"(%arg_f32) : (tensor<4xf32>) -> tensor<4xf32>
+
+  // CHECK:  [[LOG1P:%.*]] = f32[4] log-plus-one(f32[4] [[ARG_F32]])
+  %log1p = "xla_hlo.log_plus_one"(%arg_f32) : (tensor<4xf32>) -> tensor<4xf32>
+
+  // CHECK:  [[ARG_I32:%.*]] = s32[4] parameter(1)
+  // CHECK:  [[NOT:%.*]] = s32[4] not(s32[4] [[ARG_I32]])
+  %not = "xla_hlo.not"(%arg_i32) : (tensor<4xi32>) -> tensor<4xi32>
+
+  // CHECK:  [[POPCNT:%.*]] = s32[4] popcnt(s32[4] [[ARG_I32]])
+  %popcnt = "xla_hlo.popcnt"(%arg_i32) : (tensor<4xi32>) -> tensor<4xi32>
+
+  return %expm1, %log1p, %not, %popcnt : tensor<4xf32>, tensor<4xf32>, tensor<4xi32>, tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
+  // CHECK:  [[VAL_1:%.*]] = pred[4] parameter(0)
+  // CHECK:  [[VAL_2:%.*]] = pred[4] parameter(1)
+  %0 = xla_hlo.xor %arg0, %arg1 : tensor<4xi1>
+  // CHECK:  ROOT [[VAL_3:%.*]] = pred[4] xor(pred[4] [[VAL_1]], pred[4] [[VAL_2]])
+  return %0 : tensor<4xi1>
+}
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/get_dimension_size.mlir b/tensorflow/compiler/mlir/xla/tests/translate/get_dimension_size.mlir
deleted file mode 100644
index 44ff3f144f6..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/get_dimension_size.mlir
+++ /dev/null
@@ -1,10 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-func @main(%arg: tensor<4x2xf32>) -> tensor {
-  %0 = "xla_hlo.get_dimension_size"(%arg) {dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor
-  return %0 : tensor
-}
-
-// CHECK-LABEL: ENTRY
-// CHECK: [[ARG:%.*]] = f32[4,2] parameter(0)
-// CHECK: s32[] get-dimension-size(f32[4,2] [[ARG]]), dimensions={1}
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/get_element_tuple.mlir b/tensorflow/compiler/mlir/xla/tests/translate/get_element_tuple.mlir
deleted file mode 100644
index 8897a6fab33..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/get_element_tuple.mlir
+++ /dev/null
@@ -1,10 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-func @main(%arg0: tuple, tensor>) -> tensor {
-  %0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, tensor>) -> tensor
-  return %0 : tensor
-}
-
-// CHECK-LABEL: main
-// CHECK: %[[ARG0:.*]] = (f32[], s32[]) parameter(0)
-// CHECK: ROOT %[[RESULT:.*]] = f32[] get-tuple-element((f32[], s32[]) %[[ARG0]]), index=0
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/ops.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
similarity index 100%
rename from tensorflow/compiler/mlir/xla/tests/translate/ops.hlotxt
rename to tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/iota.mlir b/tensorflow/compiler/mlir/xla/tests/translate/iota.mlir
deleted file mode 100644
index e7df347a734..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/iota.mlir
+++ /dev/null
@@ -1,11 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-func @main() -> tensor<1x10xf32> {
-  %result = "xla_hlo.iota"() {
-    iota_dimension = 1 : i64
-  } : () -> tensor<1x10xf32>
-  return %result : tensor<1x10xf32>
-}
-
-// CHECK-LABEL:main
-// CHECK:  ROOT %[[RESULT:.*]] = f32[1,10] iota(), iota_dimension=1
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/pad.mlir b/tensorflow/compiler/mlir/xla/tests/translate/pad.mlir
deleted file mode 100644
index d4fba830403..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/pad.mlir
+++ /dev/null
@@ -1,12 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-func @main(%arg: tensor<4x6xf32>, %pad: tensor) -> tensor<13x19xf32> {
-  %0 = "xla_hlo.pad"(%arg, %pad) {edge_padding_high = dense<[4,5]> : tensor<2xi64>, edge_padding_low = dense<[2,3]> : tensor<2xi64>, interior_padding = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>, tensor) -> tensor<13x19xf32>
-  return %0 : tensor<13x19xf32>
-}
-
-// CHECK-LABEL:  ENTRY
-// CHECK:  [[ARG:%.*]] = f32[4,6] parameter(0)
-// CHECK:  [[PADDING_VAL:%.*]] = f32[] parameter(1)
-// CHECK-LABEL:  ROOT
-// CHECK-SAME:  f32[13,19] pad(f32[4,6] [[ARG]], f32[] [[PADDING_VAL]]), padding=2_4_1x3_5_1
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/reduce.mlir b/tensorflow/compiler/mlir/xla/tests/translate/reduce.mlir
deleted file mode 100644
index db16a2219cc..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/reduce.mlir
+++ /dev/null
@@ -1,24 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-func @main(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1x10xi32>, %arg2 : tensor, %arg3 : tensor) -> (tensor<1xf32>, tensor<1xi32>) {
-  %result0, %result1 = "xla_hlo.reduce"(%arg0, %arg1, %arg2, %arg3) ( {
-    ^bb0(%fa: tensor, %ia : tensor, %fb: tensor, %ib: tensor):   // no predecessors
-      %fmax = "xla_hlo.max"(%fa, %fb) {} : (tensor, tensor) -> tensor
-      %imax = "xla_hlo.max"(%ia, %ib) {} : (tensor, tensor) -> tensor
-      "xla_hlo.return"(%fmax, %imax) : (tensor, tensor) -> ()
-    }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<1x10xi32>, tensor, tensor) -> (tensor<1xf32>, tensor<1xi32>)
-  return %result0, %result1 : tensor<1xf32>, tensor<1xi32>
-}
-
-// CHECK: %[[REGION:region_[0-9]+]]
-// CHECK-SAME: ([[ARG_FA:.*]]: f32[], [[ARG_IA:.*]]: s32[], [[ARG_FB:.*]]: f32[], [[ARG_IB:.*]]: s32[]) -> (f32[], s32[])
-// CHECK:  %[[FMAX:.*]] = f32[] maximum(f32[] %[[ARG_FA]], f32[] %[[ARG_FB]])
-// CHECK:  %[[IMAX:.*]] = s32[] maximum(s32[] %[[ARG_IA]], s32[] %[[ARG_IB]])
-// CHECK:  ROOT %[[RESULT_REGION:.*]] = (f32[], s32[]) tuple(f32[] %[[FMAX]], s32[] %[[IMAX]])
-
-// CHECK: ENTRY %main
-// CHECK-SAME: ([[ARG0:.*]]: f32[1,10], [[ARG0:.*]]: s32[1,10], [[ARG0:.*]]: f32[], [[ARG0:.*]]: s32[]) -> (f32[1], s32[1])
-// CHECK: %[[RESULT:.*]] = (f32[1], s32[1]) reduce(f32[1,10] %Arg_0.1, s32[1,10] %Arg_1.2, f32[] %Arg_2.3, s32[] %Arg_3.4), dimensions={1}, to_apply=%[[REGION]]
-// CHECK: %[[RESULT0:.*]] = f32[1] get-tuple-element((f32[1], s32[1]) %[[RESULT]]), index=0
-// CHECK: %[[RESULT1:.*]] = s32[1] get-tuple-element((f32[1], s32[1]) %[[RESULT]]), index=1
-// CHECK: ROOT %[[RESULT:.*]] = (f32[1], s32[1]) tuple(f32[1] %[[RESULT0]], s32[1] %[[RESULT1]])
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/reduce_window.mlir b/tensorflow/compiler/mlir/xla/tests/translate/reduce_window.mlir
deleted file mode 100644
index 4ef1d1a6057..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/reduce_window.mlir
+++ /dev/null
@@ -1,27 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-func @main(%arg0: tensor<2x17x31x7xi32>) -> tensor<2x3x5x7xi32> {
-  %0 = xla_hlo.constant dense<-2147483648> : tensor
-  %1 = "xla_hlo.reduce_window"(%arg0, %0) ( {
-  ^bb0(%arg1: tensor, %arg2: tensor):	// no predecessors
-    %2 = xla_hlo.max %arg1, %arg2 : tensor
-    "xla_hlo.return"(%2) : (tensor) -> ()
-  }) {
-    window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
-    window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>,
-    padding = dense<[[0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64>,
-    base_dilations = dense<[1, 1, 1, 1]> : tensor<4xi64>,
-    window_dilations = dense<[1, 2, 2, 1]> : tensor<4xi64>
-  } : (tensor<2x17x31x7xi32>, tensor) -> tensor<2x3x5x7xi32>
-  return %1 : tensor<2x3x5x7xi32>
-}
-
-// CHECK: %[[MAX_COMPUTATION:.*]] ([[ARG0:.*]]: s32[], [[ARG1:.*]]: s32[]) -> s32[]
-// ROOT %[[RESULT:.*]] = s32[] maximum(s32[] %[[ARG0]], s32[] %[[ARG1]])
-
-// CHECK: ENTRY %main
-// CHECK-DAG: %[[ARG0:.*]] = s32[2,17,31,7] parameter(0)
-// CHECK-DAG: %[[INIT:.*]] = s32[] constant(-2147483648)
-// CHECK: ROOT %[[RESULT:.*]] = s32[2,5,8,7] reduce-window(s32[2,17,31,7] %[[ARG0]], s32[] %constant.2),
-// CHECK-SAME: window={size=1x2x2x1 stride=1x4x4x1 pad=0_0x2_0x0_2x0_0 rhs_dilate=1x2x2x1},
-// CHECK-SAME: to_apply=%[[MAX_COMPUTATION]]
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/reshape.mlir b/tensorflow/compiler/mlir/xla/tests/translate/reshape.mlir
deleted file mode 100644
index b0bb8fedb74..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/reshape.mlir
+++ /dev/null
@@ -1,10 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-func @main(%arg0: tensor<2xf32>) -> tensor<1x2xf32> {
-  %0 = "xla_hlo.reshape"(%arg0) : (tensor<2xf32>) -> tensor<1x2xf32>
-  return %0 : tensor<1x2xf32>
-}
-
-// CHECK: ENTRY %main
-// CHECK: %[[ARG0:.*]] = f32[2] parameter(0)
-// CHECK: ROOT %[[RESULT:.*]] = f32[1,2] reshape(f32[2] %[[ARG0]])
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/reverse.mlir b/tensorflow/compiler/mlir/xla/tests/translate/reverse.mlir
deleted file mode 100644
index b3393952ed6..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/reverse.mlir
+++ /dev/null
@@ -1,12 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-func @main(%arg0 : tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> {
-  %result = "xla_hlo.reverse"(%arg0) {
-    dimensions = dense<[1,2]> : tensor<2xi64>
-  } : (tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32>
-  return %result : tensor<10x11x12x13xf32>
-}
-
-// CHECK-LABEL: main
-// CHECK: %[[ARG0:.*]] = f32[10,11,12,13] parameter(0)
-// CHECK: ROOT %[[RESULT:.*]] = f32[10,11,12,13] reverse(f32[10,11,12,13] %[[ARG0]]), dimensions={1,2}
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/rng_uniform.mlir b/tensorflow/compiler/mlir/xla/tests/translate/rng_uniform.mlir
deleted file mode 100644
index 505d6b43b06..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/rng_uniform.mlir
+++ /dev/null
@@ -1,14 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-func @main() -> tensor<2x3x5xf32> {
-  %0 = xla_hlo.constant dense<0.000000e+00> : tensor
-  %1 = xla_hlo.constant dense<1.000000e+00> : tensor
-  %2 = xla_hlo.constant dense<[2, 3, 5]> : tensor<3xi64>
-  %3 = "xla_hlo.rng_uniform"(%0, %1, %2) : (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32>
-  return %3 : tensor<2x3x5xf32>
-}
-
-// CHECK: ENTRY %main
-// CHECK-DAG: %[[A:.*]] = f32[] constant(0)
-// CHECK-DAG: %[[B:.*]] = f32[] constant(1)
-// CHECK: ROOT %[[RESULT:.*]] = f32[2,3,5] rng(f32[] %[[A]], f32[] %[[B]]), distribution=rng_uniform
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/scatter.mlir b/tensorflow/compiler/mlir/xla/tests/translate/scatter.mlir
deleted file mode 100644
index 227a45bab18..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/scatter.mlir
+++ /dev/null
@@ -1,27 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-func @main(%input_tensor: tensor<200x100x300xf32>, %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> tensor<200x100x300xf32> {
-  %0 = "xla_hlo.scatter" (%input_tensor, %scatter_indices, %updates) ({
-  ^bb0(%lhs: tensor, %rhs: tensor): // no predecessors
-    %add = xla_hlo.add %lhs, %rhs : tensor
-    "xla_hlo.return"(%add) : (tensor) -> ()
-  }) {
-    scatter_dimension_numbers = {
-      update_window_dims = dense<[1]> : tensor<1xi64>,
-      inserted_window_dims = dense<[0, 1]> : tensor<2xi64>,
-      scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>,
-      index_vector_dim = 1 : i64
-    },
-    indices_are_sorted = true,
-    unique_indices = true
-  } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<200x100x300xf32>
-  return %0 : tensor<200x100x300xf32>
-}
-
-// CHECK:  [[COMPUTATION:%.*]] ({{.*}}: f32[], {{.*}}: f32[]) -> f32[]
-// CHECK-LABEL:  ENTRY
-// CHECK:  [[VAL_1:%.*]] = f32[200,100,300] parameter(0)
-// CHECK:  [[VAL_2:%.*]] = s32[10,2] parameter(1)
-// CHECK:  [[VAL_3:%.*]] = f32[10,300] parameter(2)
-// CHECK-LABEL:  ROOT
-// CHECK-SAME:  f32[200,100,300] scatter(f32[200,100,300] [[VAL_1]], s32[10,2] [[VAL_2]], f32[10,300] [[VAL_3]]), update_window_dims={1}, inserted_window_dims={0,1}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, indices_are_sorted=true, unique_indices=true, to_apply=[[COMPUTATION]]
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/select.mlir b/tensorflow/compiler/mlir/xla/tests/translate/select.mlir
deleted file mode 100644
index e4cc1b3babd..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/select.mlir
+++ /dev/null
@@ -1,14 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s --dump-input-on-failure
-
-// CHECK-LABEL: ENTRY %main
-func @main(%arg0: tensor, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
-  // CHECK: %[[ARG0:.*]] = pred[] parameter(0)
-  // CHECK: %[[COND:.*]] = pred[2,3] broadcast(pred[] %Arg_0.1), dimensions={}
-  // CHECK: %[[ARG1:.*]] = s32[2,3] parameter(1)
-  // CHECK: %[[ARG2:.*]] = s32[2,3] parameter(2)
-
-  // CHECK: ROOT %[[RES:.*]] = s32[2,3] select(pred[2,3] %[[COND]], s32[2,3] %[[ARG1]], s32[2,3] %[[ARG2]])
-  %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
-  return %0 : tensor<2x3xi32>
-}
-
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/select_and_scatter.mlir b/tensorflow/compiler/mlir/xla/tests/translate/select_and_scatter.mlir
deleted file mode 100644
index 4a8d3bbfcf3..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/select_and_scatter.mlir
+++ /dev/null
@@ -1,34 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-func @main(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> {
-  %0 = xla_hlo.constant dense<0.000000e+00> : tensor
-  %1 = "xla_hlo.select_and_scatter"(%arg0, %arg1, %0) ( {
-  ^bb0(%arg3: tensor, %arg4: tensor):	// no predecessors
-    %2 = "xla_hlo.compare"(%arg3, %arg4) {comparison_direction = "GE"} : (tensor, tensor) -> tensor
-    "xla_hlo.return"(%2) : (tensor) -> ()
-  },  {
-  ^bb0(%arg3: tensor, %arg4: tensor):	// no predecessors
-    %2 = xla_hlo.add %arg3, %arg4 : tensor
-    "xla_hlo.return"(%2) : (tensor) -> ()
-  }) {
-    window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
-    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>
-  } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> tensor<10x24x24x64xf32>
-  return %1 : tensor<10x24x24x64xf32>
-}
-
-// CHECK: %[[SELECT_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> pred[] {
-// CHECK:   ROOT %[[RESULT:.*]] = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GE
-
-// CHECK: %[[SCATTER_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> f32[] {
-// CHECK:   ROOT %[[RESULT:.*]] = f32[] add(f32[] %[[ARG0]], f32[] %[[ARG1]])
-
-// CHECK: ENTRY %main
-// CHECK:   %[[ARG0:.*]] = f32[10,24,24,64] parameter(0)
-// CHECK:   %[[ARG1:.*]] = f32[10,12,12,64] parameter(1)
-// CHECK:   %[[INIT:.*]] = f32[] constant(0)
-
-// CHECK:   ROOT %[[RESULT:.*]] = f32[10,24,24,64]
-// CHECK-SAME: select-and-scatter(f32[10,24,24,64] %[[ARG0]], f32[10,12,12,64] %[[ARG1]], f32[] %[[INIT]]),
-// CHECK-SAME: window={size=1x2x2x1 stride=1x2x2x1},
-// CHECK-SAME: select=%[[SELECT_COMPUTATION]], scatter=%[[SCATTER_COMPUTATION]]
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/slice.mlir b/tensorflow/compiler/mlir/xla/tests/translate/slice.mlir
deleted file mode 100644
index 3f31a008c1c..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/slice.mlir
+++ /dev/null
@@ -1,11 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-func @main(%arg: tensor<3x4xi32>) -> tensor<1x2xi32> {
-  %0 = "xla_hlo.slice"(%arg) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32>
-  return %0 : tensor<1x2xi32>
-}
-
-// CHECK-LABEL:  ENTRY
-// CHECK:  [[ARG:%.*]] = s32[3,4] parameter(0)
-// CHECK-LABEL:  ROOT
-// CHECK-SAME:  s32[1,2] slice(s32[3,4] [[ARG]]), slice={[1:2:1], [0:4:2]}
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/transpose.mlir b/tensorflow/compiler/mlir/xla/tests/translate/transpose.mlir
deleted file mode 100644
index 77048e6c902..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/transpose.mlir
+++ /dev/null
@@ -1,11 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-// CHECK-LABEL: ENTRY %main
-func @main(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {
-  // CHECK-NEXT: %Arg_0.1 = s32[1,2,3,4] parameter(0)
-
-  // CHECK-NEXT: ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] %Arg_0.1), dimensions={1,0,3,2}
-  %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
-  return %0 : tensor<2x1x4x3xi32>
-}
-
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/tuple.mlir b/tensorflow/compiler/mlir/xla/tests/translate/tuple.mlir
deleted file mode 100644
index 5024a66dfe6..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/tuple.mlir
+++ /dev/null
@@ -1,11 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-func @main(%arg0: tensor, %arg1 : tensor) -> tuple, tensor> {
-  %result = "xla_hlo.tuple"(%arg0, %arg1) {} : (tensor, tensor) -> tuple, tensor>
-  return %result : tuple, tensor>
-}
-
-// CHECK-LABEL: main
-// CHECK: %[[ARG0:.*]] = f32[] parameter(0)
-// CHECK: %[[ARG1:.*]] = s32[] parameter(1)
-// CHECK: ROOT %[[RESULT:.*]] = (f32[], s32[]) tuple(f32[] %[[ARG0]], s32[] %[[ARG1]])
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/unary_ops.mlir b/tensorflow/compiler/mlir/xla/tests/translate/unary_ops.mlir
deleted file mode 100644
index c4138010543..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/unary_ops.mlir
+++ /dev/null
@@ -1,21 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-module {
-  func @main(%arg_f32: tensor<4xf32>, %arg_i32: tensor<4xi32>) -> (tensor<4xf32>, tensor<4xf32>, tensor<4xi32>, tensor<4xi32>) {
-    // CHECK: [[ARG_F32:%.*]] = f32[4] parameter(0)
-    // CHECK: [[EXPM1:%.*]] = f32[4] exponential-minus-one(f32[4] [[ARG_F32]])
-    %expm1 = "xla_hlo.exponential_minus_one"(%arg_f32) : (tensor<4xf32>) -> tensor<4xf32>
-
-    // CHECK: [[LOG1P:%.*]] = f32[4] log-plus-one(f32[4] [[ARG_F32]])
-    %log1p = "xla_hlo.log_plus_one"(%arg_f32) : (tensor<4xf32>) -> tensor<4xf32>
-
-    // CHECK: [[ARG_I32:%.*]] = s32[4] parameter(1)
-    // CHECK: [[NOT:%.*]] = s32[4] not(s32[4] [[ARG_I32]])
-    %not = "xla_hlo.not"(%arg_i32) : (tensor<4xi32>) -> tensor<4xi32>
-
-    // CHECK: [[POPCNT:%.*]] = s32[4] popcnt(s32[4] [[ARG_I32]])
-    %popcnt = "xla_hlo.popcnt"(%arg_i32) : (tensor<4xi32>) -> tensor<4xi32>
-
-    return %expm1, %log1p, %not, %popcnt : tensor<4xf32>, tensor<4xf32>, tensor<4xi32>, tensor<4xi32>
-  }
-}
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/xor.mlir b/tensorflow/compiler/mlir/xla/tests/translate/xor.mlir
deleted file mode 100644
index 3ad79d633c7..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/xor.mlir
+++ /dev/null
@@ -1,11 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
-
-module {
-  func @main(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
-    // CHECK:    [[VAL_1:%.*]] = pred[4] parameter(0)
-    // CHECK:    [[VAL_2:%.*]] = pred[4] parameter(1)
-    %0 = xla_hlo.xor %arg0, %arg1 : tensor<4xi1>
-    // CHECK:   ROOT [[VAL_3:%.*]] = pred[4] xor(pred[4] [[VAL_1]], pred[4] [[VAL_2]])
-    return %0 : tensor<4xi1>
-  }
-}

From c1bf955a4fb3ee310d2de37e6a944c4893390d7b Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Tue, 3 Dec 2019 00:26:13 -0800
Subject: [PATCH 206/279] Add linkage support to LLVMFuncOp

A recent commit introduced the Linkage attribute to the LLVM dialect and used
it in the Global Op. Also use it in LLVMFuncOp. As per LLVM Language Reference,
if the linkage attribute is omitted, the function is assumed to have external
linkage.

PiperOrigin-RevId: 283493299
Change-Id: I827f7305290bc21a28a6b314e4d796809f6cd9db
---
 third_party/mlir/g3doc/Dialects/LLVM.md       |   5 +
 .../include/mlir/Dialect/LLVMIR/LLVMOps.td    |  17 +-
 .../GPUCommon/OpToFuncCallLowering.h          |   2 +-
 .../StandardToLLVM/ConvertStandardToLLVM.cpp  |   6 +-
 .../lib/Dialect/LLVMIR/IR/LLVMDialect.cpp     | 166 +++++++++++++-----
 .../mlir/lib/IR/FunctionImplementation.cpp    |   3 +-
 6 files changed, 146 insertions(+), 53 deletions(-)

diff --git a/third_party/mlir/g3doc/Dialects/LLVM.md b/third_party/mlir/g3doc/Dialects/LLVM.md
index ed0cad2df1f..9791352aa56 100644
--- a/third_party/mlir/g3doc/Dialects/LLVM.md
+++ b/third_party/mlir/g3doc/Dialects/LLVM.md
@@ -72,6 +72,11 @@ llvm.func @foo(%arg0: !llvm.i64) {
   llvm.return
 }
 
+// A function with `internal` linkage.
+llvm.func internal @internal_func() {
+  llvm.return
+}
+
 ```
 
 ### LLVM IR operations
diff --git a/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 324937a5c6d..573542ba838 100644
--- a/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -583,9 +583,12 @@ def LLVM_GlobalOp
   let verifier = "return ::verify(*this);";
 }
 
-def LLVM_LLVMFuncOp : LLVM_ZeroResultOp<"func",
-      [NativeOpTrait<"IsIsolatedFromAbove">, NativeOpTrait<"FunctionLike">,
-       Symbol]> {
+def LLVM_LLVMFuncOp
+    : LLVM_ZeroResultOp<"func",
+                        [NativeOpTrait<"IsIsolatedFromAbove">,
+                         NativeOpTrait<"FunctionLike">, Symbol]>,
+      Arguments<(ins DefaultValuedAttr:$linkage)> {
   let summary = "LLVM dialect function, has wrapped LLVM IR function type";
 
   let regions = (region AnyRegion:$body);
@@ -594,7 +597,8 @@ def LLVM_LLVMFuncOp : LLVM_ZeroResultOp<"func",
 
   let builders = [
     OpBuilder<"Builder *builder, OperationState &result, StringRef name, "
-              "LLVMType type, ArrayRef attrs = {}, "
+              "LLVMType type, LLVM::Linkage linkage = LLVM::Linkage::External, "
+              "ArrayRef attrs = {}, "
               "ArrayRef argAttrs = {}">
   ];
 
@@ -627,10 +631,7 @@ def LLVM_LLVMFuncOp : LLVM_ZeroResultOp<"func",
 
   let verifier = [{ return ::verify(*this); }];
   let printer = [{ printLLVMFuncOp(p, *this); }];
-  let parser = [{
-    return impl::parseFunctionLikeOp(parser, result, /*allowVariadic=*/true,
-                                     buildLLVMFunctionType);
-  }];
+  let parser = [{ return parseLLVMFuncOp(parser, result); }];
 }
 
 def LLVM_NullOp
diff --git a/third_party/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/third_party/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index d681e4c86ea..e06e88b92f1 100644
--- a/third_party/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/third_party/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -97,7 +97,7 @@ private:
       return llvm::cast(*funcOp);
 
     mlir::OpBuilder b(op->getParentOfType());
-    return b.create(op->getLoc(), funcName, funcType, llvm::None);
+    return b.create(op->getLoc(), funcName, funcType);
   }
 
   const std::string f32Func;
diff --git a/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index d226766a3fc..2db02db7f0c 100644
--- a/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -443,9 +443,11 @@ struct FuncOpConversion : public LLVMLegalizationPattern {
       attributes.push_back(attr);
     }
 
-    // Create an LLVM funcion.
+    // Create an LLVM funcion, use external linkage by default until MLIR
+    // functions have linkage.
     auto newFuncOp = rewriter.create(
-        op->getLoc(), funcOp.getName(), llvmType, attributes);
+        op->getLoc(), funcOp.getName(), llvmType, LLVM::Linkage::External,
+        attributes);
     rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                 newFuncOp.end());
 
diff --git a/third_party/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/third_party/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index a8c676ff696..00911012c1d 100644
--- a/third_party/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/third_party/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -862,6 +862,10 @@ static ParseResult parseConstantOp(OpAsmParser &parser,
 // Builder, printer and verifier for LLVM::GlobalOp.
 //===----------------------------------------------------------------------===//
 
+/// Returns the name used for the linkge attribute. This *must* correspond to
+/// the name of the attribute in ODS.
+static StringRef getLinkageAttrName() { return "linkage"; }
+
 void GlobalOp::build(Builder *builder, OperationState &result, LLVMType type,
                      bool isConstant, Linkage linkage, StringRef name,
                      Attribute value, ArrayRef attrs) {
@@ -872,52 +876,46 @@ void GlobalOp::build(Builder *builder, OperationState &result, LLVMType type,
     result.addAttribute("constant", builder->getUnitAttr());
   if (value)
     result.addAttribute("value", value);
-  result.addAttribute(
-      "linkage", builder->getI64IntegerAttr(static_cast(linkage)));
+  result.addAttribute(getLinkageAttrName(), builder->getI64IntegerAttr(
+                                                static_cast(linkage)));
   result.attributes.append(attrs.begin(), attrs.end());
   result.addRegion();
 }
 
-// Prints the keyword for the linkage type using the printer.
-static void printLinkage(OpAsmPrinter &p, LLVM::Linkage linkage) {
+// Returns the textual representation of the given linkage.
+static StringRef linkageToStr(LLVM::Linkage linkage) {
   switch (linkage) {
   case LLVM::Linkage::Private:
-    p << "private";
-    return;
+    return "private";
   case LLVM::Linkage::Internal:
-    p << "internal";
-    return;
+    return "internal";
   case LLVM::Linkage::AvailableExternally:
-    p << "available_externally";
-    return;
+    return "available_externally";
   case LLVM::Linkage::Linkonce:
-    p << "linkonce";
-    return;
+    return "linkonce";
   case LLVM::Linkage::Weak:
-    p << "weak";
-    return;
+    return "weak";
   case LLVM::Linkage::Common:
-    p << "common";
-    return;
+    return "common";
   case LLVM::Linkage::Appending:
-    p << "appending";
-    return;
+    return "appending";
   case LLVM::Linkage::ExternWeak:
-    p << "extern_weak";
-    return;
+    return "extern_weak";
   case LLVM::Linkage::LinkonceODR:
-    p << "linkonce_odr";
-    return;
+    return "linkonce_odr";
   case LLVM::Linkage::WeakODR:
-    p << "weak_odr";
-    return;
+    return "weak_odr";
   case LLVM::Linkage::External:
-    p << "external";
-    return;
+    return "external";
   }
   llvm_unreachable("unknown linkage type");
 }
 
+// Prints the keyword for the linkage type using the printer.
+static void printLinkage(OpAsmPrinter &p, LLVM::Linkage linkage) {
+  p << linkageToStr(linkage);
+}
+
 static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) {
   p << op.getOperationName() << ' ';
   printLinkage(p, op.linkage());
@@ -931,7 +929,7 @@ static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) {
   p << ')';
   p.printOptionalAttrDict(op.getAttrs(),
                           {SymbolTable::getSymbolAttrName(), "type", "constant",
-                           "value", "linkage"});
+                           "value", getLinkageAttrName()});
 
   // Print the trailing type unless it's a string global.
   if (op.getValueOrNull().dyn_cast_or_null())
@@ -970,7 +968,8 @@ static ParseResult parseOptionalLinkageKeyword(OpAsmParser &parser,
                "weak_odr", "external"});
   if (index == -1)
     return failure();
-  result.addAttribute("linkage", parser.getBuilder().getI64IntegerAttr(index));
+  result.addAttribute(getLinkageAttrName(),
+                      parser.getBuilder().getI64IntegerAttr(index));
   return success();
 }
 
@@ -1118,12 +1117,15 @@ static ParseResult parseShuffleVectorOp(OpAsmParser &parser,
 //===----------------------------------------------------------------------===//
 
 void LLVMFuncOp::build(Builder *builder, OperationState &result, StringRef name,
-                       LLVMType type, ArrayRef attrs,
+                       LLVMType type, LLVM::Linkage linkage,
+                       ArrayRef attrs,
                        ArrayRef argAttrs) {
   result.addRegion();
   result.addAttribute(SymbolTable::getSymbolAttrName(),
                       builder->getStringAttr(name));
   result.addAttribute("type", TypeAttr::get(type));
+  result.addAttribute(getLinkageAttrName(), builder->getI64IntegerAttr(
+                                                static_cast(linkage)));
   result.attributes.append(attrs.begin(), attrs.end());
   if (argAttrs.empty())
     return;
@@ -1137,15 +1139,16 @@ void LLVMFuncOp::build(Builder *builder, OperationState &result, StringRef name,
       result.addAttribute(getArgAttrName(i, argAttrName), argDict);
 }
 
-// Build an LLVM function type from the given lists of input and output types.
+// Builds an LLVM function type from the given lists of input and output types.
 // Returns a null type if any of the types provided are non-LLVM types, or if
 // there is more than one output type.
-static Type buildLLVMFunctionType(Builder &b, ArrayRef inputs,
-                                  ArrayRef outputs,
-                                  impl::VariadicFlag variadicFlag,
-                                  std::string &errorMessage) {
+static Type buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc,
+                                  ArrayRef inputs, ArrayRef outputs,
+                                  impl::VariadicFlag variadicFlag) {
+  Builder &b = parser.getBuilder();
   if (outputs.size() > 1) {
-    errorMessage = "expected zero or one function result";
+    parser.emitError(loc, "failed to construct function type: expected zero or "
+                          "one function result");
     return {};
   }
 
@@ -1154,7 +1157,8 @@ static Type buildLLVMFunctionType(Builder &b, ArrayRef inputs,
   for (auto t : inputs) {
     auto llvmTy = t.dyn_cast();
     if (!llvmTy) {
-      errorMessage = "expected LLVM type for function arguments";
+      parser.emitError(loc, "failed to construct function type: expected LLVM "
+                            "type for function arguments");
       return {};
     }
     llvmInputs.push_back(llvmTy);
@@ -1170,16 +1174,71 @@ static Type buildLLVMFunctionType(Builder &b, ArrayRef inputs,
   LLVMType llvmOutput = outputs.empty() ? LLVMType::getVoidTy(dialect)
                                         : outputs.front().dyn_cast();
   if (!llvmOutput) {
-    errorMessage = "expected LLVM type for function results";
+    parser.emitError(loc, "failed to construct function type: expected LLVM "
+                          "type for function results");
     return {};
   }
   return LLVMType::getFunctionTy(llvmOutput, llvmInputs,
                                  variadicFlag.isVariadic());
 }
 
-// Print the LLVMFuncOp.  Collects argument and result types and passes them
-// to the trait printer.  Drops "void" result since it cannot be parsed back.
+// Parses an LLVM function.
+//
+// operation ::= `llvm.func` linkage? function-signature function-attributes?
+//               function-body
+//
+static ParseResult parseLLVMFuncOp(OpAsmParser &parser,
+                                   OperationState &result) {
+  // Default to external linkage if no keyword is provided.
+  if (failed(parseOptionalLinkageKeyword(parser, result)))
+    result.addAttribute(getLinkageAttrName(),
+                        parser.getBuilder().getI64IntegerAttr(
+                            static_cast(LLVM::Linkage::External)));
+
+  StringAttr nameAttr;
+  SmallVector entryArgs;
+  SmallVector, 1> argAttrs;
+  SmallVector, 1> resultAttrs;
+  SmallVector argTypes;
+  SmallVector resultTypes;
+  bool isVariadic;
+
+  auto signatureLocation = parser.getCurrentLocation();
+  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
+                             result.attributes) ||
+      impl::parseFunctionSignature(parser, /*allowVariadic=*/true, entryArgs,
+                                   argTypes, argAttrs, isVariadic, resultTypes,
+                                   resultAttrs))
+    return failure();
+
+  auto type =
+      buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes,
+                            impl::VariadicFlag(isVariadic));
+  if (!type)
+    return failure();
+  result.addAttribute(impl::getTypeAttrName(), TypeAttr::get(type));
+
+  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
+    return failure();
+  impl::addArgAndResultAttrs(parser.getBuilder(), result, argAttrs,
+                             resultAttrs);
+
+  auto *body = result.addRegion();
+  return parser.parseOptionalRegion(
+      *body, entryArgs, entryArgs.empty() ? llvm::ArrayRef() : argTypes);
+}
+
+// Print the LLVMFuncOp. Collects argument and result types and passes them to
+// helper functions. Drops "void" result since it cannot be parsed back. Skips
+// the external linkage since it is the default value.
 static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) {
+  p << op.getOperationName() << ' ';
+  if (op.linkage() != LLVM::Linkage::External) {
+    printLinkage(p, op.linkage());
+    p << ' ';
+  }
+  p.printSymbolName(op.getName());
+
   LLVMType fnType = op.getType();
   SmallVector argTypes;
   SmallVector resTypes;
@@ -1191,7 +1250,15 @@ static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) {
   if (!returnType.getUnderlyingType()->isVoidTy())
     resTypes.push_back(returnType);
 
-  impl::printFunctionLikeOp(p, op, argTypes, op.isVarArg(), resTypes);
+  impl::printFunctionSignature(p, op, argTypes, op.isVarArg(), resTypes);
+  impl::printFunctionAttributes(p, op, argTypes.size(), resTypes.size(),
+                                {getLinkageAttrName()});
+
+  // Print the body if this is not an external function.
+  Region &body = op.body();
+  if (!body.empty())
+    p.printRegion(body, /*printEntryBlockArgs=*/false,
+                  /*printBlockTerminators=*/true);
 }
 
 // Hook for OpTrait::FunctionLike, called after verifying that the 'type'
@@ -1227,9 +1294,26 @@ unsigned LLVMFuncOp::getNumFuncResults() {
   return 1;
 }
 
+// Verifies LLVM- and implementation-specific properties of the LLVM func Op:
+// - functions don't have 'common' linkage
+// - external functions have 'external' or 'extern_weak' linkage;
+// - vararg is (currently) only supported for external functions;
+// - entry block arguments are of LLVM types and match the function signature.
 static LogicalResult verify(LLVMFuncOp op) {
-  if (op.isExternal())
+  if (op.linkage() == LLVM::Linkage::Common)
+    return op.emitOpError()
+           << "functions cannot have '" << linkageToStr(LLVM::Linkage::Common)
+           << "' linkage";
+
+  if (op.isExternal()) {
+    if (op.linkage() != LLVM::Linkage::External &&
+        op.linkage() != LLVM::Linkage::ExternWeak)
+      return op.emitOpError()
+             << "external functions must have '"
+             << linkageToStr(LLVM::Linkage::External) << "' or '"
+             << linkageToStr(LLVM::Linkage::ExternWeak) << "' linkage";
     return success();
+  }
 
   if (op.isVarArg())
     return op.emitOpError("only external functions can be variadic");
diff --git a/third_party/mlir/lib/IR/FunctionImplementation.cpp b/third_party/mlir/lib/IR/FunctionImplementation.cpp
index a1fc21e11ea..66c0d8af6d3 100644
--- a/third_party/mlir/lib/IR/FunctionImplementation.cpp
+++ b/third_party/mlir/lib/IR/FunctionImplementation.cpp
@@ -71,7 +71,8 @@ parseArgumentList(OpAsmParser &parser, bool allowVariadic,
   };
 
   // Parse the function arguments.
-  if (parser.parseOptionalRParen()) {
+  isVariadic = false;
+  if (failed(parser.parseOptionalRParen())) {
     do {
       unsigned numTypedArguments = argTypes.size();
       if (parseArgument())

From 18c7bad93bf8f50a5abd3a56f0d94b8040adaf4f Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Tue, 3 Dec 2019 00:45:08 -0800
Subject: [PATCH 207/279] Add deprecation note for GL delegate.

PiperOrigin-RevId: 283495266
Change-Id: I40a7477e80fd59db6bc682e94088a84483d5f50e
---
 tensorflow/lite/delegates/gpu/BUILD         |  1 +
 tensorflow/lite/delegates/gpu/gl_delegate.h | 11 +++++++++++
 2 files changed, 12 insertions(+)

diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD
index 83fa5872a0f..4cfbeff2081 100644
--- a/tensorflow/lite/delegates/gpu/BUILD
+++ b/tensorflow/lite/delegates/gpu/BUILD
@@ -37,6 +37,7 @@ cc_library(
         "//conditions:default": [],
     }),
     deps = [
+        "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/types:span",
         "//tensorflow/lite:kernel_api",
         "//tensorflow/lite:minimal_logging",
diff --git a/tensorflow/lite/delegates/gpu/gl_delegate.h b/tensorflow/lite/delegates/gpu/gl_delegate.h
index f1d30fd946e..bfc15fb120e 100644
--- a/tensorflow/lite/delegates/gpu/gl_delegate.h
+++ b/tensorflow/lite/delegates/gpu/gl_delegate.h
@@ -19,6 +19,7 @@ limitations under the License.
 #include 
 
 #include 
+#include "absl/base/macros.h"
 #include "tensorflow/lite/c/common.h"
 
 #ifdef SWIG
@@ -39,6 +40,15 @@ limitations under the License.
 extern "C" {
 #endif  // __cplusplus
 
+// WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING
+//
+// GPU delegate declared in this file is OBSOLETE and replaced with the delegate
+// declared in delegate.h. New delegate combines all GL, CL and soon
+// Vulkan-based implementations in one.
+// Please migrate before end of 2019.
+//
+// WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING
+
 // LINT.IfChange
 enum TfLiteGlObjectType {
   TFLITE_GL_OBJECT_TYPE_FASTEST = 0,
@@ -109,6 +119,7 @@ TFL_CAPI_EXPORT TfLiteGpuDelegateOptions TfLiteGpuDelegateOptionsDefault();
 //   .preferred_gl_object_type = TFLITE_GL_OBJECT_TYPE_FASTEST,
 //   .dynamic_batch_enabled = false,
 // },
+ABSL_DEPRECATED("Use TfLiteGpuDelegateV2Create defined in delegate.h instead.")
 TFL_CAPI_EXPORT TfLiteDelegate* TfLiteGpuDelegateCreate(
     const TfLiteGpuDelegateOptions* options);
 

From 12bacab7bb32db44c29f9532a6172e084e83e042 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Tue, 3 Dec 2019 01:02:53 -0800
Subject: [PATCH 208/279] compat: Update forward compatibility horizon to
 2019-12-03

PiperOrigin-RevId: 283497411
Change-Id: I83d62bc06a5d7fc84680921fbd99a1778b57fa04
---
 tensorflow/python/compat/compat.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 69e6fdd100a..6c3d92593b2 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -31,7 +31,7 @@ from tensorflow.python.util.tf_export import tf_export
 # This value changes every day with an automatic CL. It can be modified in code
 # via `forward_compatibility_horizon()` or with the environment variable
 # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2019, 12, 2)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2019, 12, 3)
 _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
 _FORWARD_COMPATIBILITY_DATE_NUMBER = None
 

From 8fad723f0487b42c17d7f9e7770cb51dddd30125 Mon Sep 17 00:00:00 2001
From: Adrian Kuegel 
Date: Tue, 3 Dec 2019 01:30:07 -0800
Subject: [PATCH 209/279] Add replay_computation target for mlir_gpu backend.

PiperOrigin-RevId: 283500767
Change-Id: I2839e5e698427a9df1997aed2e57d5b865bfb7e6
---
 tensorflow/compiler/xla/tools/BUILD | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD
index 09d2d76cabf..da20d28ea81 100644
--- a/tensorflow/compiler/xla/tools/BUILD
+++ b/tensorflow/compiler/xla/tools/BUILD
@@ -95,6 +95,14 @@ tf_cc_binary(
     ],
 )
 
+tf_cc_binary(
+    name = "replay_computation_mlir_gpu",
+    deps = [
+        ":replay_computation_library",
+        "//tensorflow/compiler/xla/service:mlir_gpu_plugin",
+    ],
+)
+
 tf_cc_binary(
     name = "replay_computation_interpreter",
     deps = [

From a56901ffc97d9e6e96555e4879b15a4166810c1c Mon Sep 17 00:00:00 2001
From: Alexander Belyaev 
Date: Tue, 3 Dec 2019 01:55:18 -0800
Subject: [PATCH 210/279] [Linalg] Update/fix documentation for
 linalg.indexed_generic.

PiperOrigin-RevId: 283503642
Change-Id: Iee41c0a75c73e0f9b6e5359b71b1f66f814e100c
---
 .../mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td b/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td
index e0070a8da35..afaf039ffd5 100644
--- a/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td
+++ b/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td
@@ -567,8 +567,8 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
         To support inplace updates in a generic fashion, the signature of the
         function must be:
         ```
-          fun([input views element types], [output views element types])
-            -> ([output views element types])
+          fun([index types for induction variables], [input views element types],
+              [output views element types]) -> ([output views element types])
         ```
       - indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input
         and output view. Such AffineMapAttr specifies the mapping between the
@@ -587,7 +587,7 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
     Example:
     Defining a #matmul_trait attribute in MLIR can be done as follows:
       ```mlir
-        func @fma(%a: f32, %b: f32, %c: f32) -> f32 {
+        func @fma(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32) -> f32 {
           %d = mulf %a, %b: f32
           %e = addf %c, %d: f32
           return %e: f32

From 95f838a3d5db2f8e677493f5c124d16097acc71a Mon Sep 17 00:00:00 2001
From: Lei Zhang 
Date: Tue, 3 Dec 2019 04:49:20 -0800
Subject: [PATCH 211/279] [spirv] Add spv.SubgroupBallotKHROp

PiperOrigin-RevId: 283522284
Change-Id: Ie11fbd0f548ef4569bab9cce05848f2d7ab2fdc1
---
 third_party/mlir/BUILD                        |  1 +
 .../include/mlir/Dialect/SPIRV/SPIRVBase.td   |  8 +-
 .../mlir/Dialect/SPIRV/SPIRVGroupOps.td       | 74 +++++++++++++++++++
 .../include/mlir/Dialect/SPIRV/SPIRVOps.td    |  7 +-
 .../mlir/lib/Dialect/SPIRV/SPIRVOps.cpp       | 27 ++++++-
 .../mlir/utils/spirv/gen_spirv_dialect.py     |  6 +-
 6 files changed, 112 insertions(+), 11 deletions(-)
 create mode 100644 third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td

diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD
index 76aecb2088c..73fbf86cde7 100644
--- a/third_party/mlir/BUILD
+++ b/third_party/mlir/BUILD
@@ -951,6 +951,7 @@ filegroup(
         "include/mlir/Dialect/SPIRV/SPIRVCastOps.td",
         "include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td",
         "include/mlir/Dialect/SPIRV/SPIRVGLSLOps.td",
+        "include/mlir/Dialect/SPIRV/SPIRVGroupOps.td",
         "include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td",
         "include/mlir/Dialect/SPIRV/SPIRVOps.td",
         "include/mlir/Dialect/SPIRV/SPIRVStructureOps.td",
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index e1897a9e295..bfb7497aada 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -953,6 +953,8 @@ class SPV_ScalarOrVectorOf :
 def SPV_ScalarOrVector : AnyTypeOf<[SPV_Scalar, SPV_Vector]>;
 def SPV_ScalarOrVectorOrPtr : AnyTypeOf<[SPV_ScalarOrVector, SPV_AnyPtr]>;
 
+def SPV_I32Vec4 : VectorOfLengthAndType<[4], [I32]>;
+
 // TODO(antiagainst): Use a more appropriate way to model optional operands
 class SPV_Optional : Variadic;
 
@@ -1107,6 +1109,7 @@ def SPV_OC_OpReturn                 : I32EnumAttrCase<"OpReturn", 253>;
 def SPV_OC_OpReturnValue            : I32EnumAttrCase<"OpReturnValue", 254>;
 def SPV_OC_OpUnreachable            : I32EnumAttrCase<"OpUnreachable", 255>;
 def SPV_OC_OpModuleProcessed        : I32EnumAttrCase<"OpModuleProcessed", 330>;
+def SPV_OC_OpSubgroupBallotKHR      : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
 
 def SPV_OpcodeAttr :
     I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
@@ -1146,10 +1149,9 @@ def SPV_OpcodeAttr :
       SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier,
       SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel,
       SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
-      SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpModuleProcessed
+      SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpModuleProcessed,
+      SPV_OC_OpSubgroupBallotKHR
       ]> {
-    let returnType = "::mlir::spirv::Opcode";
-    let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";
     let cppNamespace = "::mlir::spirv";
 }
 
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
new file mode 100644
index 00000000000..5f60e6b0135
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
@@ -0,0 +1,74 @@
+//===-- SPIRVGroupOps.td - MLIR SPIR-V (Sub)Group Ops ------*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains group and subgroup ops for the SPIR-V dialect. It
+// corresponds to "3.32.21. Group and Subgroup Instructions" of the SPIR-V
+// specification.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SPIRV_GROUP_OPS
+#define SPIRV_GROUP_OPS
+
+// -----
+
+def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> {
+  let summary = "See extension SPV_KHR_shader_ballot";
+
+  let description = [{
+    Computes a bitfield value combining the Predicate value from all invocations
+    in the current Subgroup that execute the same dynamic instance of this
+    instruction. The bit is set to one if the corresponding invocation is active
+    and the predicate is evaluated to true; otherwise, it is set to zero.
+
+    Predicate must be a Boolean type.
+
+    Result Type must be a 4 component vector of 32 bit integer types.
+
+    Result is a set of bitfields where the first invocation is represented in bit
+    0 of the first vector component and the last (up to SubgroupSize) is the
+    higher bit number of the last bitmask needed to represent all bits of the
+    subgroup invocations.
+
+    ### Custom assembly form
+
+    ``` {.ebnf}
+    subgroup-ballot-op ::= ssa-id `=` `spv.SubgroupBallotKHR`
+                                ssa-use `:` `vector` `<` 4 `x` `i32` `>`
+    ```
+
+    For example:
+
+    ```
+    %0 = spv.SubgroupBallotKHR %predicate : vector<4xi32>
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_Bool:$predicate
+  );
+
+  let results = (outs
+    SPV_I32Vec4:$result
+  );
+
+  let verifier = [{ return success(); }];
+}
+
+// -----
+
+#endif // SPIRV_GROUP_OPS
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
index 41d729da777..178db0add4e 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
@@ -34,11 +34,10 @@ include "mlir/Dialect/SPIRV/SPIRVArithmeticOps.td"
 include "mlir/Dialect/SPIRV/SPIRVBitOps.td"
 include "mlir/Dialect/SPIRV/SPIRVCastOps.td"
 include "mlir/Dialect/SPIRV/SPIRVControlFlowOps.td"
-include "mlir/Dialect/SPIRV/SPIRVLogicalOps.td"
-// Pull in ops for defining the SPIR-V module structure
-include "mlir/Dialect/SPIRV/SPIRVStructureOps.td"
-// Pull in ops for extended instruction set for GLSL
 include "mlir/Dialect/SPIRV/SPIRVGLSLOps.td"
+include "mlir/Dialect/SPIRV/SPIRVGroupOps.td"
+include "mlir/Dialect/SPIRV/SPIRVLogicalOps.td"
+include "mlir/Dialect/SPIRV/SPIRVStructureOps.td"
 
 // -----
 
diff --git a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index e8896fac526..6e115f7ba76 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -2526,6 +2526,28 @@ static LogicalResult verify(spirv::StoreOp storeOp) {
   return verifyMemoryAccessAttribute(storeOp);
 }
 
+//===----------------------------------------------------------------------===//
+// spv.SubgroupBallotKHROp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseSubgroupBallotKHROp(OpAsmParser &parser,
+                                            OperationState &state) {
+  OpAsmParser::OperandType operandInfo;
+  Type resultType;
+  IntegerType i1Type = parser.getBuilder().getI1Type();
+  if (parser.parseOperand(operandInfo) || parser.parseColonType(resultType) ||
+      parser.resolveOperand(operandInfo, i1Type, state.operands))
+    return failure();
+
+  return parser.addTypeToList(resultType, state.types);
+}
+
+static void print(spirv::SubgroupBallotKHROp ballotOp, OpAsmPrinter &printer) {
+  printer << spirv::SubgroupBallotKHROp::getOperationName() << ' ';
+  printer.printOperand(ballotOp.predicate());
+  printer << " : " << ballotOp.getType();
+}
+
 //===----------------------------------------------------------------------===//
 // spv.Undef
 //===----------------------------------------------------------------------===//
@@ -2595,11 +2617,10 @@ static ParseResult parseVariableOp(OpAsmParser &parser, OperationState &state) {
   state.addTypes(ptrType);
 
   // Resolve the initializer operand
-  SmallVector init;
   if (initInfo) {
-    if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(), init))
+    if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
+                              state.operands))
       return failure();
-    state.addOperands(init);
   }
 
   auto attr = parser.getBuilder().getI32IntegerAttr(
diff --git a/third_party/mlir/utils/spirv/gen_spirv_dialect.py b/third_party/mlir/utils/spirv/gen_spirv_dialect.py
index 5ef56675a1a..d1530f77d5a 100755
--- a/third_party/mlir/utils/spirv/gen_spirv_dialect.py
+++ b/third_party/mlir/utils/spirv/gen_spirv_dialect.py
@@ -426,7 +426,11 @@ def get_op_definition(instruction, doc, existing_info):
   # Make sure we have ', ' to separate the category arguments from traits
   category_args = category_args.rstrip(', ') + ', '
 
-  summary, text = doc.split('\n', 1)
+  if '\n' in doc:
+    summary, text = doc.split('\n', 1)
+  else:
+    summary = doc
+    text = ''
   wrapper = textwrap.TextWrapper(
       width=76, initial_indent='    ', subsequent_indent='    ')
 

From fd62258250257e4af380909a5833814023508ae4 Mon Sep 17 00:00:00 2001
From: Stephan Herhut 
Date: Tue, 3 Dec 2019 05:11:20 -0800
Subject: [PATCH 212/279] Extend conversion of SubViewOp to llvm to also
 support cases where size and stride are constant (i.e., there are no size and
 stride operands).

We recently added canonicalization that rewrites constant size and stride operands to
SubViewOp into static information in the type, so these patterns now occur during code
generation.

PiperOrigin-RevId: 283524688
Change-Id: Idfe28b4e36d84daa9a3d7f1d16404ded6cac7ccd
---
 .../StandardToLLVM/ConvertStandardToLLVM.cpp  | 30 +++++++++++++++----
 1 file changed, 24 insertions(+), 6 deletions(-)

diff --git a/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 2db02db7f0c..0d932208893 100644
--- a/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -1506,10 +1506,12 @@ struct SubViewOpLowering : public LLVMLegalizationPattern {
     if (!sourceElementTy || !targetDescTy)
       return matchFailure();
 
-    // Early exit for 0-D and operands lesser than `rank` corner cases.
+    // Currently, only rank > 0 and full or no operands are supported. Fail to
+    // convert otherwise.
     unsigned rank = sourceMemRefType.getRank();
-    if (viewMemRefType.getRank() == 0 || rank != dynamicOffsets.size() ||
-        rank != dynamicSizes.size() || rank != dynamicStrides.size())
+    if (viewMemRefType.getRank() == 0 || (rank != dynamicOffsets.size()) ||
+        (!dynamicSizes.empty() && rank != dynamicSizes.size()) ||
+        (!dynamicStrides.empty() && rank != dynamicStrides.size()))
       return matchFailure();
 
     int64_t offset;
@@ -1539,6 +1541,17 @@ struct SubViewOpLowering : public LLVMLegalizationPattern {
     for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i)
       strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
 
+    // Fill in missing dynamic sizes.
+    auto llvmIndexType = lowering.convertType(rewriter.getIndexType());
+    if (dynamicSizes.empty()) {
+      dynamicSizes.reserve(viewMemRefType.getRank());
+      auto shape = viewMemRefType.getShape();
+      for (auto extent : shape) {
+        dynamicSizes.push_back(rewriter.create(
+            loc, llvmIndexType, rewriter.getI64IntegerAttr(extent)));
+      }
+    }
+
     // Offset.
     Value *baseOffset = sourceMemRef.offset(rewriter, loc);
     for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) {
@@ -1552,9 +1565,14 @@ struct SubViewOpLowering : public LLVMLegalizationPattern {
     // Update sizes and strides.
     for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
       targetMemRef.setSize(rewriter, loc, i, dynamicSizes[i]);
-      targetMemRef.setStride(rewriter, loc, i,
-                             rewriter.create(
-                                 loc, dynamicStrides[i], strideValues[i]));
+      Value *newStride;
+      if (dynamicStrides.empty())
+        newStride = rewriter.create(
+            loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
+      else
+        newStride = rewriter.create(loc, dynamicStrides[i],
+                                                 strideValues[i]);
+      targetMemRef.setStride(rewriter, loc, i, newStride);
     }
 
     rewriter.replaceOp(op, {targetMemRef});

From 2cfa79f98497cd7d9ce9d26de38ec046323d1d31 Mon Sep 17 00:00:00 2001
From: Diego Caballero 
Date: Tue, 3 Dec 2019 06:09:21 -0800
Subject: [PATCH 213/279] AffineLoopFusion: Prevent fusion of multi-out-edge
 producer loops

https://github.com/tensorflow/mlir/pull/162 introduced a bug that
incorrectly allowed fusion of producer loops with multiple outgoing
edges. This commit fixes that problem. It also introduces a new flag to
disable sibling loop fusion so that we can test producer-consumer fusion
in isolation.

Closes #259

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/259 from dcaballe:dcaballe/fix_multi_out_edge_producer_fusion 578d5661705fd5c56c555832d5e0528df88c5282
PiperOrigin-RevId: 283531105
Change-Id: I3a6173463ea20bd35555c24fa451bfbf2dfac098
---
 third_party/mlir/lib/Transforms/LoopFusion.cpp | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/third_party/mlir/lib/Transforms/LoopFusion.cpp b/third_party/mlir/lib/Transforms/LoopFusion.cpp
index 7985ca1c5ef..cda35297abc 100644
--- a/third_party/mlir/lib/Transforms/LoopFusion.cpp
+++ b/third_party/mlir/lib/Transforms/LoopFusion.cpp
@@ -1005,17 +1005,19 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
 
 // Checks if node 'srcId' can be safely fused into node 'dstId'. Node 'srcId'
 // may write to multiple memrefs but it is required that only one of them,
-// 'srcLiveOutStoreOp', have an output edge.
+// 'srcLiveOutStoreOp', has output edges.
 // Returns true if 'dstNode's read/write region to 'memref' is a super set of
-// 'srcNode's write region to 'memref'.
+// 'srcNode's write region to 'memref' and 'srcId' has only one output edge.
 // TODO(andydavis) Generalize this to handle more live in/out cases.
 static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
                                            AffineStoreOp srcLiveOutStoreOp,
                                            MemRefDependenceGraph *mdg) {
   assert(srcLiveOutStoreOp && "Expected a valid store op");
-  assert(mdg->getOutEdgeCount(srcId) == 1 && "Expected only one output edge");
   auto *dstNode = mdg->getNode(dstId);
   Value *memref = srcLiveOutStoreOp.getMemRef();
+  // Return false if 'srcNode' has more than one output edge on 'memref'.
+  if (mdg->getOutEdgeCount(srcId, memref) > 1)
+    return false;
 
   // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOp' on 'memref'.
   MemRefRegion srcWriteRegion(srcLiveOutStoreOp.getLoc());

From 3b94eb4fac3f81f33c4a5b92e5641d9a2c865909 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Tue, 3 Dec 2019 06:22:31 -0800
Subject: [PATCH 214/279] Fix ViewOp to have at most one offset operand

As described in the documentation, ViewOp is expected to take an optional
dynamic offset followed by a list of dynamic sizes. However, the ViewOp parser
did not include a check for the offset being a single value and accepeted a
list of values instead.

Furthermore, several tests have been exercising the wrong syntax of a ViewOp,
passing multiple values to the dyanmic stride list, which was not caught by the
parser. The trailing values could have been erronously interpreted as dynamic
sizes. This is likely due to resyntaxing of the ViewOp, with the previous
syntax taking the list of sizes before the offset. Update the tests to use the
syntax with the offset preceding the sizes.

Worse, the conversion of ViewOp to the LLVM dialect assumed the wrong order of
operands with offset in the trailing position, and erronously relied on the
permissive parsing that interpreted trailing dynamic offset values as leading
dynamic sizes. Fix the lowering to use the correct order of operands.

PiperOrigin-RevId: 283532506
Change-Id: I4d80570132c0e1194d865f6282b87e6b89e9879d
---
 .../StandardToLLVM/ConvertStandardToLLVM.cpp    | 17 +++++++++++------
 .../mlir/lib/Dialect/StandardOps/Ops.cpp        | 10 ++++++++--
 2 files changed, 19 insertions(+), 8 deletions(-)

diff --git a/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 0d932208893..793997e9045 100644
--- a/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -1663,13 +1663,14 @@ struct ViewOpLowering : public LLVMLegalizationPattern {
     // Field 3: Copy the offset in aligned pointer.
     unsigned numDynamicSizes = llvm::size(viewOp.getDynamicSizes());
     (void)numDynamicSizes;
+    bool hasDynamicOffset = offset == MemRefType::getDynamicStrideOrOffset();
     auto sizeAndOffsetOperands = adaptor.operands();
-    assert(llvm::size(sizeAndOffsetOperands) == numDynamicSizes + 1 ||
-           offset != MemRefType::getDynamicStrideOrOffset());
-    Value *baseOffset = (offset != MemRefType::getDynamicStrideOrOffset())
+    assert(llvm::size(sizeAndOffsetOperands) ==
+           numDynamicSizes + (hasDynamicOffset ? 1 : 0));
+    Value *baseOffset = !hasDynamicOffset
                             ? createIndexConstant(rewriter, loc, offset)
                             // TODO(ntv): better adaptor.
-                            : sizeAndOffsetOperands.back();
+                            : sizeAndOffsetOperands.front();
     targetMemRef.setOffset(rewriter, loc, baseOffset);
 
     // Early exit for 0-D corner case.
@@ -1681,10 +1682,14 @@ struct ViewOpLowering : public LLVMLegalizationPattern {
       return op->emitWarning("cannot cast to non-contiguous shape"),
              matchFailure();
     Value *stride = nullptr, *nextSize = nullptr;
+    // Drop the dynamic stride from the operand list, if present.
+    ArrayRef sizeOperands(sizeAndOffsetOperands);
+    if (hasDynamicOffset)
+      sizeOperands = sizeOperands.drop_front();
     for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
       // Update size.
-      Value *size = getSize(rewriter, loc, viewMemRefType.getShape(),
-                            sizeAndOffsetOperands, i);
+      Value *size =
+          getSize(rewriter, loc, viewMemRefType.getShape(), sizeOperands, i);
       targetMemRef.setSize(rewriter, loc, i, size);
       // Update stride.
       stride = getStride(rewriter, loc, strides, nextSize, stride, i);
diff --git a/third_party/mlir/lib/Dialect/StandardOps/Ops.cpp b/third_party/mlir/lib/Dialect/StandardOps/Ops.cpp
index 31431be5054..361135c4e29 100644
--- a/third_party/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/third_party/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -2327,9 +2327,15 @@ static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
   SmallVector sizesInfo;
   auto indexType = parser.getBuilder().getIndexType();
   Type srcType, dstType;
+  llvm::SMLoc offsetLoc;
+  if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
+      parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
+    return failure();
+
+  if (offsetInfo.size() > 1)
+    return parser.emitError(offsetLoc) << "expects 0 or 1 offset operand";
+
   return failure(
-      parser.parseOperand(srcInfo) ||
-      parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square) ||
       parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
       parser.parseOptionalAttrDict(result.attributes) ||
       parser.parseColonType(srcType) ||

From 11896b96786181c4d2b7250bd5af7749c46fadf9 Mon Sep 17 00:00:00 2001
From: Andy Ly 
Date: Tue, 3 Dec 2019 08:22:40 -0800
Subject: [PATCH 215/279] Remove `output_arrays` from GraphImportConfig and
 rename `output_arrays_order` to `outputs`.

Output node names can be derived from `output_arrays_order` and current uses of `output_arrays` is always a set based off of `output_arrays_order`.

PiperOrigin-RevId: 283549297
Change-Id: I3ef230b3b2d10270e00777c1397e2302986c3deb
---
 .../compiler/mlir/lite/flatbuffer_import.cc   |  9 ++---
 .../lite/python/graphdef_to_tfl_flatbuffer.cc |  4 +--
 tensorflow/compiler/mlir/tensorflow/BUILD     |  2 ++
 .../mlir/tensorflow/translate/import_model.cc | 33 +++++++++++--------
 .../translate/mlir_roundtrip_flags.cc         | 11 +++----
 .../translate/mlir_roundtrip_flags.h          | 14 +++-----
 .../tensorflow/translate/tf_mlir_translate.cc | 15 +++++----
 tensorflow/compiler/tf2xla/mlir_tf2xla.cc     |  3 +-
 8 files changed, 46 insertions(+), 45 deletions(-)

diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
index 1a459477ac1..8f24ad441a6 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
@@ -880,11 +880,8 @@ static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr,
       mlir::FileLineColLoc::get(input->getBufferIdentifier(), 0, 0, context);
 
   // Parses output_arrays_order from command line option.
-  absl::flat_hash_set output_set;
-  std::vector output_arrays_order;
-  if (!tensorflow::ParseOutputArrayInfo(output_arrays_string, &output_set,
-                                        &output_arrays_order)
-           .ok()) {
+  std::vector outputs;
+  if (!tensorflow::ParseOutputArrayInfo(output_arrays_string, &outputs).ok()) {
     return emitError(loc, "parsing output array info failed ")
                << output_arrays_string,
            nullptr;
@@ -892,7 +889,7 @@ static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr,
 
   return tflite::FlatBufferToMlir(
       absl::string_view(input->getBufferStart(), input->getBufferSize()),
-      context, loc, output_arrays_order);
+      context, loc, outputs);
 }
 
 static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg(
diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
index 895d12f61ef..51bd1e4540c 100644
--- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
+++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
@@ -237,8 +237,8 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
   // Parse output arrays.
   std::vector output_arrays(model_flags.output_arrays().begin(),
                                     model_flags.output_arrays().end());
-  TF_RETURN_IF_ERROR(tensorflow::ParseOutputArrayInfo(
-      output_arrays, &specs.output_arrays, &specs.output_arrays_order));
+  TF_RETURN_IF_ERROR(
+      tensorflow::ParseOutputArrayInfo(output_arrays, &specs.outputs));
 
   // Other flags.
   bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index 7cfa802b1c3..5484988d0f5 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -469,6 +469,7 @@ cc_library(
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/platform:types",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/container:inlined_vector",
@@ -671,6 +672,7 @@ cc_library(
         ":import_utils",
         ":mangling_util",
         ":mlir_roundtrip_flags",
+        "//tensorflow/core:graph",
         "//tensorflow/core:lib_proto_parsing",
         "//tensorflow/core:ops",
         "//tensorflow/core:protos_all_cc",
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
index ef4b1d1682e..da2e6a67445 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
@@ -21,6 +21,7 @@ limitations under the License.
 
 #include "absl/algorithm/container.h"
 #include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
 #include "absl/container/inlined_vector.h"
 #include "absl/strings/escaping.h"
 #include "absl/strings/numbers.h"
@@ -75,6 +76,7 @@ limitations under the License.
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/graph/tensor_id.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/platform/protobuf.h"
@@ -500,7 +502,8 @@ Status ImporterBase::GetInputOutputNodes(
     TF_RETURN_IF_ERROR(add_node(input.first));
   }
 
-  for (const auto& output_node_name : specs_.output_arrays) {
+  for (const auto& output : specs_.outputs) {
+    auto output_node_name = std::string(ParseTensorName(output).first);
     TF_RETURN_IF_ERROR(add_node(output_node_name));
   }
 
@@ -1588,7 +1591,7 @@ StatusOr GraphDefImporter::Convert(
   llvm::SmallVector attrs;
   if (specs.graph_as_function) {
     if (specs.prune_unused_nodes || !specs.inputs.empty() ||
-        !specs.output_arrays.empty() || !specs.output_arrays_order.empty())
+        !specs.outputs.empty())
       return errors::InvalidArgument(
           "Pruning of graph is currently unsupported when the main graph is "
           "converted to a function.");
@@ -1622,7 +1625,7 @@ StatusOr GraphDefImporter::Convert(
     // TODO(prakalps): Refactor to keep attribute strings (tf.entry_function,
     // tf.versions) shared by importer and exporter in a centralized place.
     // Record the input and output mapping.
-    if (!specs.inputs.empty() || !specs.output_arrays.empty()) {
+    if (!specs.inputs.empty() || !specs.outputs.empty()) {
       mlir::Builder b(context);
       std::string s;
       llvm::raw_string_ostream ss(s);
@@ -1632,7 +1635,7 @@ StatusOr GraphDefImporter::Convert(
           ",");
       auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str()));
       s.clear();
-      mlir::interleave(specs.output_arrays_order, ss, ",");
+      mlir::interleave(specs.outputs, ss, ",");
       auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
 
       attrs.push_back(b.getNamedAttr("tf.entry_function",
@@ -1665,9 +1668,13 @@ StatusOr GraphDefImporter::InferMainFunctionType(
     absl::InlinedVector* arg_nodes,
     absl::InlinedVector* ret_nodes) {
   // Finds out all the input nodes and output nodes.
-  if (!specs.inputs.empty() || !specs.output_arrays.empty()) {
+  absl::flat_hash_set output_node_names;
+  for (const auto& output_tensor : specs.outputs) {
+    output_node_names.insert(ParseTensorName(output_tensor).node());
+  }
+  if (!specs.inputs.empty() || !specs.outputs.empty()) {
     arg_nodes->resize(specs.inputs.size());
-    ret_nodes->resize(specs.output_arrays_order.size());
+    ret_nodes->resize(specs.outputs.size());
 
     for (Node* n : GetOrderedNodes()) {
       // Handle inputs/arguments.
@@ -1677,17 +1684,17 @@ StatusOr GraphDefImporter::InferMainFunctionType(
       }
 
       // Handle outputs/returns.
-      if (specs.output_arrays.find(n->name()) != specs.output_arrays.end()) {
-        for (int i = 0, e = specs.output_arrays_order.size(); i != e; ++i) {
+      if (output_node_names.contains(n->name())) {
+        for (int i = 0, e = specs.outputs.size(); i != e; ++i) {
           std::pair name_and_port =
-              absl::StrSplit(specs.output_arrays_order[i], ':');
+              absl::StrSplit(specs.outputs[i], ':');
           auto name = name_and_port.first;
           if (name != n->name()) continue;
           int port = 0;
           if (!name_and_port.second.empty() &&
               !absl::SimpleAtoi(name_and_port.second, &port)) {
             return errors::InvalidArgument("Invalid port specification: ",
-                                           specs.output_arrays_order[i]);
+                                           specs.outputs[i]);
           }
           (*ret_nodes)[i] = {n, port};
         }
@@ -1726,10 +1733,10 @@ StatusOr GraphDefImporter::InferMainFunctionType(
   }
 
   llvm::SmallVector ret_types;
-  ret_types.reserve(specs.output_arrays.size());
-  for (int i = 0, e = specs.output_arrays_order.size(); i != e; ++i) {
+  ret_types.reserve(specs.outputs.size());
+  for (int i = 0, e = specs.outputs.size(); i != e; ++i) {
     if (ret_nodes->at(i).node == nullptr) {
-      return errors::InvalidArgument("Output ", specs.output_arrays_order[i],
+      return errors::InvalidArgument("Output ", specs.outputs[i],
                                      " was not found in graph");
     }
   }
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc
index 133e4831356..b2cf906be0d 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc
@@ -33,19 +33,16 @@ limitations under the License.
 namespace tensorflow {
 
 Status ParseOutputArrayInfo(absl::string_view array_names,
-                            absl::flat_hash_set* array,
-                            std::vector* order) {
+                            std::vector* outputs) {
   std::vector output_names = absl::StrSplit(array_names, ',');
-  return ParseOutputArrayInfo(output_names, array, order);
+  return ParseOutputArrayInfo(output_names, outputs);
 }
 
 Status ParseOutputArrayInfo(const std::vector& output_names,
-                            absl::flat_hash_set* array,
-                            std::vector* order) {
+                            std::vector* outputs) {
   for (auto& output_name : output_names) {
     if (output_name.empty()) continue;
-    array->insert(string(*absl::StrSplit(output_name, ':').begin()));
-    order->push_back(output_name);
+    outputs->push_back(output_name);
   }
   return Status::OK();
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h
index ebc862999e9..9b260883638 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h
+++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h
@@ -40,11 +40,9 @@ struct GraphImportConfig {
       llvm::MapVector>;
   // Maps input node names to node data types and shapes.
   InputArrays inputs;
-  // Output node names.
-  absl::flat_hash_set output_arrays;
-  // nodes:index strings for the output as specified on the command line.
-  std::vector output_arrays_order;
-  // setting prune_unused_nodes to true, would prune unreachable nodes if
+  // name:index strings for the output as specified on the command line.
+  std::vector outputs;
+  // Setting prune_unused_nodes to true, would prune unreachable nodes if
   // output_arrays is specified.
   bool prune_unused_nodes = false;
   // If true, inputs of type LegacyFedInput are replaced with Placeholder ops.
@@ -73,12 +71,10 @@ struct GraphExportConfig {
 // Parses the command line flag strings to the specification of nodes in
 // the Graph.
 Status ParseOutputArrayInfo(absl::string_view array_names,
-                            absl::flat_hash_set* array,
-                            std::vector* order);
+                            std::vector* outputs);
 
 Status ParseOutputArrayInfo(const std::vector& output_names,
-                            absl::flat_hash_set* array,
-                            std::vector* order);
+                            std::vector* outputs);
 
 // Parses the command line flag strings to the specification of nodes in
 // the Graph. `data_types` input string can be empty since the flag is optional.
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc
index cd422a66bc5..5c59eace5cc 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc
@@ -32,6 +32,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/versions.pb.h"
+#include "tensorflow/core/graph/tensor_id.h"
 #include "tensorflow/core/grappler/utils/transitive_fanin.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
@@ -63,16 +64,18 @@ static StatusOr GraphdefToMlirImport(
   specs.upgrade_legacy = upgrade_legacy;
   TF_RETURN_IF_ERROR(ParseInputArrayInfo(input_arrays, input_dtypes,
                                          input_shapes, &specs.inputs));
-  TF_RETURN_IF_ERROR(ParseOutputArrayInfo(output_arrays, &specs.output_arrays,
-                                          &specs.output_arrays_order));
+  TF_RETURN_IF_ERROR(ParseOutputArrayInfo(output_arrays, &specs.outputs));
   // TODO(b/142828368): Pruning should not be needed when TF import
   // supports importing graphs w/ unregistered ops natively.
   GraphDef pruned_graph_def;
   if (specs.prune_unused_nodes) {
-    std::vector terminal_nodes(specs.output_arrays.begin(),
-                                       specs.output_arrays.end());
-    for (const auto entry : specs.inputs) {
-      terminal_nodes.push_back(entry.first);
+    std::vector terminal_nodes;
+    terminal_nodes.reserve(specs.outputs.size() + specs.inputs.size());
+    for (const auto& output : specs.outputs) {
+      terminal_nodes.push_back(std::string(ParseTensorName(output).node()));
+    }
+    for (const auto& input : specs.inputs) {
+      terminal_nodes.push_back(input.first);
     }
     TF_RETURN_IF_ERROR(tensorflow::grappler::SetTransitiveFaninGraph(
         graphdef, &pruned_graph_def, terminal_nodes));
diff --git a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc
index 7617f10a372..01325af3d39 100644
--- a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc
+++ b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc
@@ -77,8 +77,7 @@ Status ConvertOutputInfo(const tf2xla::Config& config,
     array_names.push_back(fetch.id().node_name());
   }
 
-  return ParseOutputArrayInfo(array_names, &specs->output_arrays,
-                              &specs->output_arrays_order);
+  return ParseOutputArrayInfo(array_names, &specs->outputs);
 }
 
 }  // namespace

From 2454d7bd23173b0ca286c2072954629753edc60e Mon Sep 17 00:00:00 2001
From: Benoit Jacob 
Date: Tue, 3 Dec 2019 08:46:03 -0800
Subject: [PATCH 216/279] use nullptr for null pointers.

PiperOrigin-RevId: 283553199
Change-Id: I0d3c335d2f877f10ba11c1755c32d9568a2bd4be
---
 .../experimental/ruy/prepacked_cache_test.cc     | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/tensorflow/lite/experimental/ruy/prepacked_cache_test.cc b/tensorflow/lite/experimental/ruy/prepacked_cache_test.cc
index 0e912495d09..efb6f2b358c 100644
--- a/tensorflow/lite/experimental/ruy/prepacked_cache_test.cc
+++ b/tensorflow/lite/experimental/ruy/prepacked_cache_test.cc
@@ -33,7 +33,7 @@ TEST(PrepackedCacheTest, TestCacheEjection) {
   mat1.data_size = 16;
   mat1.sums_size = 8;
   prepacked_cache.AllocatePrepackedMatrix(&mat1);
-  auto cache_key1 = std::make_pair(reinterpret_cast(0), mat1.data);
+  auto cache_key1 = std::make_pair(nullptr, mat1.data);
   prepacked_cache.Insert(cache_key1, mat1);
   std::this_thread::sleep_for(std::chrono::milliseconds(10));
   // Get a time point after the insertion into the cache.
@@ -49,7 +49,7 @@ TEST(PrepackedCacheTest, TestCacheEjection) {
   mat2.sums_size = 4;
   prepacked_cache.AllocatePrepackedMatrix(&mat2);
 
-  auto cache_key2 = std::make_pair(reinterpret_cast(0), mat2.data);
+  auto cache_key2 = std::make_pair(nullptr, mat2.data);
   prepacked_cache.Insert(cache_key2, mat2);
   // The cache size was exceeded by inserting mat2. Ensure that mat1 was
   // ejected.
@@ -67,7 +67,7 @@ TEST(PrepackedCacheTest, TestCacheBasic) {
   mat1.sums_size = 8;
   prepacked_cache.AllocatePrepackedMatrix(&mat1);
 
-  auto cache_key1 = std::make_pair(reinterpret_cast(0), mat1.data);
+  auto cache_key1 = std::make_pair(nullptr, mat1.data);
   prepacked_cache.Insert(cache_key1, mat1);
   std::this_thread::sleep_for(std::chrono::milliseconds(10));
   EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend());
@@ -77,7 +77,7 @@ TEST(PrepackedCacheTest, TestCacheBasic) {
   mat2.sums_size = 4;
   prepacked_cache.AllocatePrepackedMatrix(&mat2);
 
-  auto cache_key2 = std::make_pair(reinterpret_cast(0), mat2.data);
+  auto cache_key2 = std::make_pair(nullptr, mat2.data);
   std::this_thread::sleep_for(std::chrono::milliseconds(10));
   prepacked_cache.Insert(cache_key2, mat2);
   // The cache size was not exceeded by inserting mat2. Ensure that mat1 was not
@@ -95,7 +95,7 @@ TEST(PrepackedCacheTest, TestCacheEjection2) {
   mat1.data_size = 16;
   mat1.sums_size = 8;
   prepacked_cache.AllocatePrepackedMatrix(&mat1);
-  auto cache_key1 = std::make_pair(reinterpret_cast(0), mat1.data);
+  auto cache_key1 = std::make_pair(nullptr, mat1.data);
   prepacked_cache.Insert(cache_key1, mat1);
   std::this_thread::sleep_for(std::chrono::milliseconds(10));
 
@@ -104,7 +104,7 @@ TEST(PrepackedCacheTest, TestCacheEjection2) {
   mat2.data_size = 16;
   mat2.sums_size = 8;
   prepacked_cache.AllocatePrepackedMatrix(&mat2);
-  auto cache_key2 = std::make_pair(reinterpret_cast(0), mat2.data);
+  auto cache_key2 = std::make_pair(nullptr, mat2.data);
   prepacked_cache.Insert(cache_key2, mat2);
   std::this_thread::sleep_for(std::chrono::milliseconds(10));
 
@@ -113,7 +113,7 @@ TEST(PrepackedCacheTest, TestCacheEjection2) {
   mat31.data_size = 16;
   mat31.sums_size = 8;
   prepacked_cache.AllocatePrepackedMatrix(&mat31);
-  auto cache_key3 = std::make_pair(reinterpret_cast(0), mat31.data);
+  auto cache_key3 = std::make_pair(nullptr, mat31.data);
   prepacked_cache.Insert(cache_key3, mat31);
   std::this_thread::sleep_for(std::chrono::milliseconds(10));
 
@@ -128,7 +128,7 @@ TEST(PrepackedCacheTest, TestCacheEjection2) {
   mat4.data_size = 16;
   mat4.sums_size = 8;
   prepacked_cache.AllocatePrepackedMatrix(&mat4);
-  auto cache_key4 = std::make_pair(reinterpret_cast(0), mat4.data);
+  auto cache_key4 = std::make_pair(nullptr, mat4.data);
   prepacked_cache.Insert(cache_key4, mat4);
   std::this_thread::sleep_for(std::chrono::milliseconds(10));
 

From 6547b6511baa368e87901837c070fe614a72933a Mon Sep 17 00:00:00 2001
From: Sean Silva 
Date: Tue, 3 Dec 2019 09:01:54 -0800
Subject: [PATCH 217/279] Use separate allocator for cached prepacked matrix
 allocations.

This CL splits out the SystemAlignedAlloc/Free functions so that they are independently usable in this way.

ruy::Allocator is a highly specialized allocator designed for the hot path of multiple gemm's.

The use case of cached pre-packing has a very different set of tradeoffs.

PiperOrigin-RevId: 283555950
Change-Id: I012a6ba4386e1727866e677965b857f62992d5f1
---
 tensorflow/lite/experimental/ruy/allocator.cc |  6 +-
 tensorflow/lite/experimental/ruy/allocator.h  | 75 +++++++++----------
 .../lite/experimental/ruy/prepacked_cache.cc  | 13 +---
 .../lite/experimental/ruy/prepacked_cache.h   | 41 ++++++++--
 4 files changed, 78 insertions(+), 57 deletions(-)

diff --git a/tensorflow/lite/experimental/ruy/allocator.cc b/tensorflow/lite/experimental/ruy/allocator.cc
index 8c4536bdeb1..60a905136fd 100644
--- a/tensorflow/lite/experimental/ruy/allocator.cc
+++ b/tensorflow/lite/experimental/ruy/allocator.cc
@@ -26,19 +26,19 @@ namespace ruy {
 
 namespace detail {
 
-void *AlignedAllocator::SystemAlignedAlloc(std::ptrdiff_t num_bytes) {
+void *SystemAlignedAlloc(std::ptrdiff_t num_bytes) {
 #ifdef _WIN32
   return _aligned_malloc(num_bytes, kAlignment);
 #else
   void *ptr;
-  if (posix_memalign(&ptr, kAlignment, num_bytes)) {
+  if (posix_memalign(&ptr, kMinimumBlockAlignment, num_bytes)) {
     return nullptr;
   }
   return ptr;
 #endif
 }
 
-void AlignedAllocator::SystemAlignedFree(void *ptr) {
+void SystemAlignedFree(void *ptr) {
 #ifdef _WIN32
   _aligned_free(ptr);
 #else
diff --git a/tensorflow/lite/experimental/ruy/allocator.h b/tensorflow/lite/experimental/ruy/allocator.h
index f233090ce49..2f5c98d6870 100644
--- a/tensorflow/lite/experimental/ruy/allocator.h
+++ b/tensorflow/lite/experimental/ruy/allocator.h
@@ -34,38 +34,49 @@ inline void* VoidPtrAdd(void* p, std::ptrdiff_t offset) {
   return reinterpret_cast(addr);
 }
 
-// Simple allocator designed to converge to a steady-state where all
+// Minimum alignment for blocks.
+//
+// Considerations:
+//  - This needs to be at least the alignment of any usual data type.
+//  - It's useful that this is at least the size of a cache line to limit
+//    possible cache side effects (if only on performance behavior).
+//  - It's useful that this is at least the size of SIMD registers, as
+//    some SIMD instruction sets have at least performance behavior
+//    differences (e.g. NEON) or even different requirements (e.g. SSE)
+//    based on that.
+//  - It's useful that this is at least the size of an "exclusive reservation
+//    granule" on ARM, meaning that if we use this Allocator to allocate
+//    an atomic variable, there will be no side effects from other things
+//    contending for exclusive/atomic memory accesses to it. While the
+//    ARM reference manual mentions that this granule size may be as large
+//    as 2048 bytes, in practice we observe it to be 64 bytes. It can
+//    be queried cheaply, at runtime, from userspace, if needed.
+static constexpr std::ptrdiff_t kMinimumBlockAlignment = 64;
+
+// Primitive allocation functions obtaining aligned memory from the
+// operating system.
+void* SystemAlignedAlloc(std::ptrdiff_t num_bytes);
+void SystemAlignedFree(void* ptr);
+
+// Specialized allocator designed to converge to a steady-state where all
 // allocations are bump-ptr allocations from an already-allocated buffer.
 //
 // To support these constraints, this allocator only supports two
 // operations.
 // - AllocateAlignedBytes: allocates a pointer to storage of a specified
-// size, which must be aligned to kAlignment.
+// size, which must be aligned to kMinimumBlockAlignment.
 // - FreeAll: frees all previous allocations (but retains the internal
 // buffer to minimize future calls into the system allocator).
 //
+// This class is specialized for supporting just those two operations
+// under this specific steady-state usage pattern. Extending this class
+// with new allocation interfaces that don't fit that pattern is probably not
+// the right choice. Instead, build a new class on top of
+// SystemAlignedAlloc/SystemAlignedFree.
+//
 // All operations happen on aligned blocks for simplicity.
 class AlignedAllocator {
  public:
-  // Alignment of allocated blocks.
-  //
-  // Considerations:
-  //  - This needs to be at least the alignment of any usual data type.
-  //  - It's useful that this is at least the size of a cache line to limit
-  //    possible cache side effects (if only on performance behavior).
-  //  - It's useful that this is at least the size of SIMD registers, as
-  //    some SIMD instruction sets have at least performance behavior
-  //    differences (e.g. NEON) or even different requirements (e.g. SSE)
-  //    based on that.
-  //  - It's useful that this is at least the size of an "exclusive reservation
-  //    granule" on ARM, meaning that if we use this Allocator to allocate
-  //    an atomic variable, there will be no side effects from other things
-  //    contending for exclusive/atomic memory accesses to it. While the
-  //    ARM reference manual mentions that this granule size may be as large
-  //    as 2048 bytes, in practice we observe it to be 64 bytes. It can
-  //    be queried cheaply, at runtime, from userspace, if needed.
-  static constexpr std::ptrdiff_t kAlignment = 64;
-
   void operator=(const AlignedAllocator&) = delete;
   ~AlignedAllocator() {
     FreeAll();
@@ -74,7 +85,7 @@ class AlignedAllocator {
 
   void* AllocateAlignedBytes(std::ptrdiff_t num_bytes) {
     RUY_DCHECK_GT(num_bytes, 0);
-    RUY_DCHECK((num_bytes & (kAlignment - 1)) == 0);
+    RUY_DCHECK((num_bytes & (kMinimumBlockAlignment - 1)) == 0);
     if (void* p = AllocateFast(num_bytes)) {
       return p;
     }
@@ -105,17 +116,7 @@ class AlignedAllocator {
     fallback_blocks_total_size_ = 0;
   }
 
-  void FreeOne(void* ptr) {
-    for (auto p = fallback_blocks_.begin(); p != fallback_blocks_.end(); ++p) {
-      if (*p == ptr) {
-        SystemAlignedFree(ptr);
-        fallback_blocks_.erase(p);
-        return;
-      }
-    }
-    RUY_DCHECK(false);  // Trying to free pointer we did not allocate.
-  }
-
+ private:
   void* AllocateFast(std::ptrdiff_t num_bytes) {
     if (current_ + num_bytes > size_) {
       return nullptr;
@@ -132,12 +133,6 @@ class AlignedAllocator {
     return p;
   }
 
- private:
-  // Primitive allocation functions obtaining aligned memory from the
-  // operating system.
-  void* SystemAlignedAlloc(std::ptrdiff_t num_bytes);
-  void SystemAlignedFree(void* ptr);
-
   // Theory of operation:
   //
   // - ptr_, current_, and size_ implement a basic bump-ptr allocator.
@@ -171,7 +166,7 @@ class Allocator {
       return nullptr;
     }
     return aligned.AllocateAlignedBytes(
-        round_up_pot(num_bytes, detail::AlignedAllocator::kAlignment));
+        round_up_pot(num_bytes, detail::kMinimumBlockAlignment));
   }
   template 
   void Allocate(std::ptrdiff_t count, Pointer* out) {
diff --git a/tensorflow/lite/experimental/ruy/prepacked_cache.cc b/tensorflow/lite/experimental/ruy/prepacked_cache.cc
index 93fc4363044..2bd23f834c4 100644
--- a/tensorflow/lite/experimental/ruy/prepacked_cache.cc
+++ b/tensorflow/lite/experimental/ruy/prepacked_cache.cc
@@ -58,19 +58,14 @@ void PrepackedCache::EjectOne() {
   PrepackedMatrix &pmatrix = oldest->second.first;
   cache_size_ -= pmatrix.data_size;
   cache_size_ -= pmatrix.sums_size;
-  allocator_.FreeOne(pmatrix.data);
-  allocator_.FreeOne(pmatrix.sums);
+  allocator_.Free(pmatrix.data);
+  allocator_.Free(pmatrix.sums);
   cache_.erase(oldest);
 }
 
 void PrepackedCache::AllocatePrepackedMatrix(PrepackedMatrix *pmatrix) {
-  pmatrix->data = AllocateBytes(pmatrix->data_size);
-  pmatrix->sums = AllocateBytes(pmatrix->sums_size);
-}
-
-void *PrepackedCache::AllocateBytes(std::ptrdiff_t num_bytes) {
-  // Force system allocation for now to enable easy ejections.
-  return allocator_.AllocateSlow(num_bytes);
+  pmatrix->data = allocator_.Alloc(pmatrix->data_size);
+  pmatrix->sums = allocator_.Alloc(pmatrix->sums_size);
 }
 
 void PrepackedCache::DoInsert(const CacheKey &key,
diff --git a/tensorflow/lite/experimental/ruy/prepacked_cache.h b/tensorflow/lite/experimental/ruy/prepacked_cache.h
index 053108e61ed..9c77c48cf69 100644
--- a/tensorflow/lite/experimental/ruy/prepacked_cache.h
+++ b/tensorflow/lite/experimental/ruy/prepacked_cache.h
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PREPACKED_CACHE_H_
 #define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PREPACKED_CACHE_H_
 
+#include 
 #include 
 #include 
 #include 
@@ -27,6 +28,40 @@ limitations under the License.
 
 namespace ruy {
 
+namespace detail {
+
+// Tracks a set of blocks allocated from the underlying system allocator.
+class SystemBlockAllocator {
+ public:
+  void *Alloc(std::ptrdiff_t num_bytes) {
+    void *p = detail::SystemAlignedAlloc(num_bytes);
+    blocks_.push_back(p);
+    return p;
+  }
+
+  void Free(void *block) {
+    for (auto it = blocks_.begin(); it != blocks_.end(); ++it) {
+      if (*it == block) {
+        detail::SystemAlignedFree(block);
+        blocks_.erase(it);
+        return;
+      }
+    }
+    RUY_DCHECK(false);  // Trying to free pointer we did not allocate.
+  }
+
+  ~SystemBlockAllocator() {
+    for (void *block : blocks_) {
+      detail::SystemAlignedFree(block);
+    }
+  }
+
+ private:
+  std::vector blocks_;
+};
+
+}  // namespace detail
+
 enum CachePolicy { kNoCache, kCacheLHSOnGemV };
 
 // "Low effort" Least Recently Used Cache for Prepacked Matrices
@@ -80,12 +115,8 @@ class PrepackedCache {
 
  private:
   void EjectOne();
-  void *AllocateBytes(std::ptrdiff_t num_bytes);
   void DoInsert(const CacheKey &key, const PrepackedMatrix &matrix);
-  // Since this cache is used in the context of "pre-packing", we need to
-  // handle allocating the space for the packed matrix ourselves, so we need
-  // our own allocator.
-  AlignedAllocator allocator_;
+  detail::SystemBlockAllocator allocator_;
   std::map cache_;
   const int32_t ejection_threshold_;
   size_t cache_size_;

From e2c953b991243aafff71cea8a886f05e52b5ddb4 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Tue, 3 Dec 2019 09:05:14 -0800
Subject: [PATCH 218/279] Use explicit inference priorities instead of setting
 just allow_precision_loss

PiperOrigin-RevId: 283556973
Change-Id: If0f38f068e9a743ebb32f149a98759ce1ea4abef
---
 tensorflow/lite/examples/label_image/label_image.cc | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tensorflow/lite/examples/label_image/label_image.cc b/tensorflow/lite/examples/label_image/label_image.cc
index 452fa0c9682..a3d07a66a02 100644
--- a/tensorflow/lite/examples/label_image/label_image.cc
+++ b/tensorflow/lite/examples/label_image/label_image.cc
@@ -60,7 +60,9 @@ TfLiteDelegatePtr CreateGPUDelegate(Settings* s) {
   TfLiteGpuDelegateOptionsV2 gpu_opts = TfLiteGpuDelegateOptionsV2Default();
   gpu_opts.inference_preference =
       TFLITE_GPU_INFERENCE_PREFERENCE_SUSTAINED_SPEED;
-  gpu_opts.is_precision_loss_allowed = s->allow_fp16 ? 1 : 0;
+  gpu_opts.inference_priority1 =
+      s->allow_fp16 ? TFLITE_GPU_INFERENCE_PRIORITY_MIN_LATENCY
+                    : TFLITE_GPU_INFERENCE_PRIORITY_MAX_PRECISION;
   return evaluation::CreateGPUDelegate(s->model, &gpu_opts);
 #else
   return evaluation::CreateGPUDelegate(s->model);

From 8080cd198090d80ed74ca5a58289a2c1d65a0bca Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Tue, 3 Dec 2019 09:32:16 -0800
Subject: [PATCH 219/279] Add python bindings for ArrayAttr, AffineMapAttr.

PiperOrigin-RevId: 283561252
Change-Id: I934a7581b5b3e22529c7d79c5016050ba62b72d5
---
 third_party/mlir/bindings/python/pybind.cpp   | 124 +++++++++++++++++-
 .../mlir/bindings/python/test/test_py2and3.py |  35 +++++
 2 files changed, 157 insertions(+), 2 deletions(-)

diff --git a/third_party/mlir/bindings/python/pybind.cpp b/third_party/mlir/bindings/python/pybind.cpp
index a458837f77a..90b24cdbf4c 100644
--- a/third_party/mlir/bindings/python/pybind.cpp
+++ b/third_party/mlir/bindings/python/pybind.cpp
@@ -31,6 +31,8 @@
 #include "mlir/EDSC/Intrinsics.h"
 #include "mlir/ExecutionEngine/ExecutionEngine.h"
 #include "mlir/ExecutionEngine/OptUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Module.h"
@@ -62,6 +64,8 @@ struct PythonExpr;
 struct PythonFunctionContext;
 struct PythonStmt;
 struct PythonBlock;
+struct PythonAffineExpr;
+struct PythonAffineMap;
 
 struct PythonType {
   PythonType() : type{nullptr} {}
@@ -191,6 +195,25 @@ struct PythonMLIRModule {
   // Create a boolean attribute.
   PythonAttribute boolAttr(bool value);
 
+  // Creates an Array attribute.
+  PythonAttribute arrayAttr(const std::vector &values);
+
+  // Creates an AffineMap attribute.
+  PythonAttribute affineMapAttr(PythonAffineMap value);
+
+  // Creates an affine constant expression.
+  PythonAffineExpr affineConstantExpr(int64_t value);
+
+  // Creates an affine symbol expression.
+  PythonAffineExpr affineSymbolExpr(unsigned position);
+
+  // Creates a single constant result affine map.
+  PythonAffineMap affineConstantMap(int64_t value);
+
+  // Creates an affine map.
+  PythonAffineMap affineMap(unsigned dimCount, unsigned symbolCount,
+                            const std::vector &results);
+
   // Compile the module save the execution engine. "optLevel" and
   // "codegenOptLevel" contain the levels of optimization to run (0 to 3) for
   // transformations and codegen. -1 means ExecutionEngine default.
@@ -467,14 +490,15 @@ struct PythonAttribute {
   PythonAttribute(const PythonAttribute &other) = default;
   operator mlir_attr_t() { return attr; }
 
+  operator Attribute() const { return Attribute::getFromOpaquePointer(attr); }
+
   std::string str() const {
     if (!attr)
       return "##null attr##";
 
     std::string res;
     llvm::raw_string_ostream os(res);
-    Attribute::getFromOpaquePointer(reinterpret_cast(attr))
-        .print(os);
+    Attribute().print(os);
     return res;
   }
 
@@ -532,6 +556,46 @@ private:
   std::unordered_map attrs;
 };
 
+// Wraps mlir::AffineExpr.
+struct PythonAffineExpr {
+  PythonAffineExpr() : affine_expr() {}
+  PythonAffineExpr(const AffineExpr &a) : affine_expr(a) {}
+  PythonAffineExpr(const PythonAffineExpr &other) = default;
+
+  operator AffineExpr() const { return affine_expr; }
+  operator AffineExpr &() { return affine_expr; }
+
+  std::string str() const {
+    std::string res;
+    llvm::raw_string_ostream os(res);
+    affine_expr.print(os);
+    return res;
+  }
+
+private:
+  AffineExpr affine_expr;
+};
+
+// Wraps mlir::AffineMap.
+struct PythonAffineMap {
+  PythonAffineMap() : affine_map() {}
+  PythonAffineMap(const AffineMap &a) : affine_map(a) {}
+  PythonAffineMap(const PythonAffineMap &other) = default;
+
+  operator AffineMap() const { return affine_map; }
+  operator AffineMap &() { return affine_map; }
+
+  std::string str() const {
+    std::string res;
+    llvm::raw_string_ostream os(res);
+    affine_map.print(os);
+    return res;
+  }
+
+private:
+  AffineMap affine_map;
+};
+
 struct PythonIndexedValue {
   explicit PythonIndexedValue(PythonType type)
       : indexed(Type::getFromOpaquePointer(type.type)) {}
@@ -640,6 +704,38 @@ PythonAttribute PythonMLIRModule::boolAttr(bool value) {
   return PythonAttribute(::makeBoolAttr(&mlirContext, value));
 }
 
+PythonAttribute
+PythonMLIRModule::arrayAttr(const std::vector &values) {
+  std::vector mlir_attributes(values.begin(), values.end());
+  auto array_attr = ArrayAttr::get(
+      llvm::ArrayRef(mlir_attributes), &mlirContext);
+  return PythonAttribute(array_attr.getAsOpaquePointer());
+}
+
+PythonAttribute PythonMLIRModule::affineMapAttr(PythonAffineMap value) {
+  return PythonAttribute(AffineMapAttr::get(value).getAsOpaquePointer());
+}
+
+PythonAffineExpr PythonMLIRModule::affineConstantExpr(int64_t value) {
+  return PythonAffineExpr(getAffineConstantExpr(value, &mlirContext));
+}
+
+PythonAffineExpr PythonMLIRModule::affineSymbolExpr(unsigned position) {
+  return PythonAffineExpr(getAffineSymbolExpr(position, &mlirContext));
+}
+
+PythonAffineMap PythonMLIRModule::affineConstantMap(int64_t value) {
+  return PythonAffineMap(AffineMap::getConstantMap(value, &mlirContext));
+}
+
+PythonAffineMap
+PythonMLIRModule::affineMap(unsigned dimCount, unsigned SymbolCount,
+                            const std::vector &results) {
+  std::vector mlir_results(results.begin(), results.end());
+  return PythonAffineMap(AffineMap::get(
+      dimCount, SymbolCount, llvm::ArrayRef(mlir_results)));
+}
+
 PYBIND11_MODULE(pybind, m) {
   m.doc() =
       "Python bindings for MLIR Embedded Domain-Specific Components (EDSCs)";
@@ -801,6 +897,12 @@ PYBIND11_MODULE(pybind, m) {
           "integerAttr", &PythonMLIRModule::integerAttr,
           "Creates an mlir::IntegerAttr of the given type with the given value "
           "in the context associated with this MLIR module.")
+      .def("arrayAttr", &PythonMLIRModule::arrayAttr,
+           "Creates an mlir::ArrayAttr of the given type with the given values "
+           "in the context associated with this MLIR module.")
+      .def("affineMapAttr", &PythonMLIRModule::affineMapAttr,
+           "Creates an mlir::AffineMapAttr of the given type with the given "
+           "value in the context associated with this MLIR module.")
       .def("declare_function", &PythonMLIRModule::declareFunction,
            "Declares a new mlir::FuncOp in the current mlir::ModuleOp.  The "
            "function arguments can have attributes.  The function has no "
@@ -831,6 +933,14 @@ PYBIND11_MODULE(pybind, m) {
       .def("get_engine_address", &PythonMLIRModule::getEngineAddress,
            "Returns the address of the compiled ExecutionEngine. This is used "
            "for in-process execution.")
+      .def("affine_constant_expr", &PythonMLIRModule::affineConstantExpr,
+           "Returns an affine constant expression.")
+      .def("affine_symbol_expr", &PythonMLIRModule::affineSymbolExpr,
+           "Returns an affine symbol expression.")
+      .def("affine_constant_map", &PythonMLIRModule::affineConstantMap,
+           "Returns an affine map with single constant result.")
+      .def("affine_map", &PythonMLIRModule::affineMap, "Returns an affine map.",
+           py::arg("dimCount"), py::arg("symbolCount"), py::arg("resuls"))
       .def("__str__", &PythonMLIRModule::getIR,
            "Get the string representation of the module");
 
@@ -940,6 +1050,16 @@ PYBIND11_MODULE(pybind, m) {
       .def(py::init())
       .def("load", &PythonIndexedValue::load)
       .def("store", &PythonIndexedValue::store);
+
+  py::class_(m, "AffineExpr",
+                               "A wrapper around mlir::AffineExpr")
+      .def(py::init())
+      .def("__str__", &PythonAffineExpr::str);
+
+  py::class_(m, "AffineMap",
+                              "A wrapper around mlir::AffineMap")
+      .def(py::init())
+      .def("__str__", &PythonAffineMap::str);
 }
 
 } // namespace python
diff --git a/third_party/mlir/bindings/python/test/test_py2and3.py b/third_party/mlir/bindings/python/test/test_py2and3.py
index 2f4281ee59a..cd4d7bfcecd 100644
--- a/third_party/mlir/bindings/python/test/test_py2and3.py
+++ b/third_party/mlir/bindings/python/test/test_py2and3.py
@@ -285,6 +285,41 @@ class EdscTest:
     # CHECK-LABEL: testFunctionDeclaration
     #       CHECK: func @foo(memref<10xf32>, memref<10xf32> {llvm.noalias = true}, memref<10xf32> {readonly = true})
 
+  def testFunctionDeclarationWithAffineAttr(self):
+    self.setUp()
+    a1 = self.module.affine_constant_expr(23)
+    a2 = self.module.affine_constant_expr(44)
+    s0 = self.module.affine_symbol_expr(0)
+    aMap1 = self.module.affine_map(2, 0, [a1, a2, s0])
+    aMap2 = self.module.affine_constant_map(42)
+    affineAttr1 = self.module.affineMapAttr(aMap1)
+    affineAttr2 = self.module.affineMapAttr(aMap2)
+
+    t = self.module.make_memref_type(self.f32Type, [10])
+    t_with_attr = t({
+        "affine_attr_1": affineAttr1,
+        "affine_attr_2": affineAttr2
+    })
+
+    f = self.module.declare_function("foo", [t, t_with_attr], [])
+    printWithCurrentFunctionName(str(self.module))
+    # CHECK-LABEL: testFunctionDeclarationWithAffineAttr
+    #       CHECK:  func @foo(memref<10xf32>, memref<10xf32> {affine_attr_1 = (d0, d1) -> (23, 44, s0), affine_attr_2 = () -> (42)})
+
+  def testFunctionDeclarationWithArrayAttr(self):
+    self.setUp()
+    arrayAttr = self.module.arrayAttr([
+        self.module.integerAttr(self.i32Type, 43),
+        self.module.integerAttr(self.i32Type, 33),
+    ])
+    t = self.module.make_memref_type(self.f32Type, [10])
+    t_with_attr = t({"array_attr": arrayAttr})
+
+    f = self.module.declare_function("foo", [t, t_with_attr], [])
+    printWithCurrentFunctionName(str(self.module))
+    # CHECK-LABEL: testFunctionDeclarationWithArrayAttr
+    #       CHECK: func @foo(memref<10xf32>, memref<10xf32> {array_attr = [43 : i32, 33 : i32]})
+
   def testFunctionMultiple(self):
     self.setUp()
     with self.module.function_context("foo", [], []):

From 6915969812d3720cbd9710d3fb634ee751d5d0f8 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Tue, 3 Dec 2019 10:11:40 -0800
Subject: [PATCH 220/279] Add Python bindings for affine expressions with
 binary operators.

PiperOrigin-RevId: 283569325
Change-Id: I94cefc582362b88c3c2167b77452665a90522e24
---
 third_party/mlir/bindings/python/pybind.cpp   | 63 +++++++++++++++++++
 .../mlir/bindings/python/test/test_py2and3.py | 12 +++-
 2 files changed, 73 insertions(+), 2 deletions(-)

diff --git a/third_party/mlir/bindings/python/pybind.cpp b/third_party/mlir/bindings/python/pybind.cpp
index 90b24cdbf4c..b1be0d21336 100644
--- a/third_party/mlir/bindings/python/pybind.cpp
+++ b/third_party/mlir/bindings/python/pybind.cpp
@@ -207,6 +207,9 @@ struct PythonMLIRModule {
   // Creates an affine symbol expression.
   PythonAffineExpr affineSymbolExpr(unsigned position);
 
+  // Creates an affine dimension expression.
+  PythonAffineExpr affineDimExpr(unsigned position);
+
   // Creates a single constant result affine map.
   PythonAffineMap affineConstantMap(int64_t value);
 
@@ -565,6 +568,8 @@ struct PythonAffineExpr {
   operator AffineExpr() const { return affine_expr; }
   operator AffineExpr &() { return affine_expr; }
 
+  AffineExpr get() const { return affine_expr; }
+
   std::string str() const {
     std::string res;
     llvm::raw_string_ostream os(res);
@@ -724,6 +729,10 @@ PythonAffineExpr PythonMLIRModule::affineSymbolExpr(unsigned position) {
   return PythonAffineExpr(getAffineSymbolExpr(position, &mlirContext));
 }
 
+PythonAffineExpr PythonMLIRModule::affineDimExpr(unsigned position) {
+  return PythonAffineExpr(getAffineDimExpr(position, &mlirContext));
+}
+
 PythonAffineMap PythonMLIRModule::affineConstantMap(int64_t value) {
   return PythonAffineMap(AffineMap::getConstantMap(value, &mlirContext));
 }
@@ -937,6 +946,8 @@ PYBIND11_MODULE(pybind, m) {
            "Returns an affine constant expression.")
       .def("affine_symbol_expr", &PythonMLIRModule::affineSymbolExpr,
            "Returns an affine symbol expression.")
+      .def("affine_dim_expr", &PythonMLIRModule::affineDimExpr,
+           "Returns an affine dim expression.")
       .def("affine_constant_map", &PythonMLIRModule::affineConstantMap,
            "Returns an affine map with single constant result.")
       .def("affine_map", &PythonMLIRModule::affineMap, "Returns an affine map.",
@@ -1054,6 +1065,58 @@ PYBIND11_MODULE(pybind, m) {
   py::class_(m, "AffineExpr",
                                "A wrapper around mlir::AffineExpr")
       .def(py::init())
+      .def("__add__",
+           [](PythonAffineExpr lhs, int64_t rhs) -> PythonAffineExpr {
+             return PythonAffineExpr(lhs.get() + rhs);
+           })
+      .def("__add__",
+           [](PythonAffineExpr lhs, PythonAffineExpr rhs) -> PythonAffineExpr {
+             return PythonAffineExpr(lhs.get() + rhs.get());
+           })
+      .def("__neg__",
+           [](PythonAffineExpr lhs) -> PythonAffineExpr {
+             return PythonAffineExpr(-lhs.get());
+           })
+      .def("__sub__",
+           [](PythonAffineExpr lhs, int64_t rhs) -> PythonAffineExpr {
+             return PythonAffineExpr(lhs.get() - rhs);
+           })
+      .def("__sub__",
+           [](PythonAffineExpr lhs, PythonAffineExpr rhs) -> PythonAffineExpr {
+             return PythonAffineExpr(lhs.get() - rhs.get());
+           })
+      .def("__mul__",
+           [](PythonAffineExpr lhs, int64_t rhs) -> PythonAffineExpr {
+             return PythonAffineExpr(lhs.get() * rhs);
+           })
+      .def("__mul__",
+           [](PythonAffineExpr lhs, PythonAffineExpr rhs) -> PythonAffineExpr {
+             return PythonAffineExpr(lhs.get() * rhs.get());
+           })
+      .def("__floordiv__",
+           [](PythonAffineExpr lhs, uint64_t rhs) -> PythonAffineExpr {
+             return PythonAffineExpr(lhs.get().floorDiv(rhs));
+           })
+      .def("__floordiv__",
+           [](PythonAffineExpr lhs, PythonAffineExpr rhs) -> PythonAffineExpr {
+             return PythonAffineExpr(lhs.get().floorDiv(rhs.get()));
+           })
+      .def("ceildiv",
+           [](PythonAffineExpr lhs, uint64_t rhs) -> PythonAffineExpr {
+             return PythonAffineExpr(lhs.get().ceilDiv(rhs));
+           })
+      .def("ceildiv",
+           [](PythonAffineExpr lhs, PythonAffineExpr rhs) -> PythonAffineExpr {
+             return PythonAffineExpr(lhs.get().ceilDiv(rhs.get()));
+           })
+      .def("__mod__",
+           [](PythonAffineExpr lhs, uint64_t rhs) -> PythonAffineExpr {
+             return PythonAffineExpr(lhs.get() % rhs);
+           })
+      .def("__mod__",
+           [](PythonAffineExpr lhs, PythonAffineExpr rhs) -> PythonAffineExpr {
+             return PythonAffineExpr(lhs.get() % rhs.get());
+           })
       .def("__str__", &PythonAffineExpr::str);
 
   py::class_(m, "AffineMap",
diff --git a/third_party/mlir/bindings/python/test/test_py2and3.py b/third_party/mlir/bindings/python/test/test_py2and3.py
index cd4d7bfcecd..678e5023173 100644
--- a/third_party/mlir/bindings/python/test/test_py2and3.py
+++ b/third_party/mlir/bindings/python/test/test_py2and3.py
@@ -289,22 +289,30 @@ class EdscTest:
     self.setUp()
     a1 = self.module.affine_constant_expr(23)
     a2 = self.module.affine_constant_expr(44)
+    a3 = self.module.affine_dim_expr(1)
     s0 = self.module.affine_symbol_expr(0)
     aMap1 = self.module.affine_map(2, 0, [a1, a2, s0])
     aMap2 = self.module.affine_constant_map(42)
+    aMap3 = self.module.affine_map(
+        2, 0,
+        [a1 + a2 * a3, a1 // a3 % a2,
+         a1.ceildiv(a2), a1 - 2, a2 * 2, -a3])
+
     affineAttr1 = self.module.affineMapAttr(aMap1)
     affineAttr2 = self.module.affineMapAttr(aMap2)
+    affineAttr3 = self.module.affineMapAttr(aMap3)
 
     t = self.module.make_memref_type(self.f32Type, [10])
     t_with_attr = t({
         "affine_attr_1": affineAttr1,
-        "affine_attr_2": affineAttr2
+        "affine_attr_2": affineAttr2,
+        "affine_attr_3": affineAttr3,
     })
 
     f = self.module.declare_function("foo", [t, t_with_attr], [])
     printWithCurrentFunctionName(str(self.module))
     # CHECK-LABEL: testFunctionDeclarationWithAffineAttr
-    #       CHECK:  func @foo(memref<10xf32>, memref<10xf32> {affine_attr_1 = (d0, d1) -> (23, 44, s0), affine_attr_2 = () -> (42)})
+    #       CHECK:  func @foo(memref<10xf32>, memref<10xf32> {affine_attr_1 = (d0, d1) -> (23, 44, s0), affine_attr_2 = () -> (42), affine_attr_3 = (d0, d1) -> (d1 * 44 + 23, (23 floordiv d1) mod 44, 1, 21, 88, -d1)})
 
   def testFunctionDeclarationWithArrayAttr(self):
     self.setUp()

From b6e1920d9cba3801470dddc895538f2e05a77e4a Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Tue, 3 Dec 2019 10:20:37 -0800
Subject: [PATCH 221/279] Convert MemRefType to a linearized array in SPIR-V
 lowering.

The SPIR-V lowering used nested !spv.arrays to represented
multi-dimensional arrays, with the hope that in-conjunction with the
layout annotations, the shape and layout of memref can be represented
directly. It is unclear though how portable this representation will
end up being. It will rely on driver compilers implementing complex
index computations faithfully. A more portable approach is to use
linearized arrays to represent memrefs and explicitly instantiate all
the index computation in SPIR-V. This gives added benefit that we can
further optimize the generated code in MLIR before generating the
SPIR-V binary.

PiperOrigin-RevId: 283571167
Change-Id: Ib3bb989e60ebc752b04d078f72f88f4832821ed4
---
 .../ConvertStandardToSPIRV.cpp                | 94 +++++++++++--------
 .../mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp  | 74 ++++++++-------
 2 files changed, 98 insertions(+), 70 deletions(-)

diff --git a/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index 4a3d25fbd38..ee2dfedc15b 100644
--- a/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -28,6 +28,48 @@
 
 using namespace mlir;
 
+//===----------------------------------------------------------------------===//
+// Utility functions for operation conversion
+//===----------------------------------------------------------------------===//
+
+/// Performs the index computation to get to the element pointed to by
+/// `indices` using the layout map of `baseType`.
+
+// TODO(ravishankarm) : This method assumes that the `origBaseType` is a
+// MemRefType with AffineMap that has static strides. Handle dynamic strides
+spirv::AccessChainOp getElementPtr(OpBuilder &builder,
+                                   SPIRVTypeConverter &typeConverter,
+                                   Location loc, MemRefType origBaseType,
+                                   Value *basePtr, ArrayRef indices) {
+  // Get base and offset of the MemRefType and verify they are static.
+  int64_t offset;
+  SmallVector strides;
+  if (failed(getStridesAndOffset(origBaseType, strides, offset)) ||
+      llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
+    return nullptr;
+  }
+
+  auto indexType = typeConverter.getIndexType(builder.getContext());
+
+  Value *ptrLoc = nullptr;
+  assert(indices.size() == strides.size());
+  for (auto index : enumerate(indices)) {
+    Value *strideVal = builder.create(
+        loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
+    Value *update =
+        builder.create(loc, strideVal, index.value());
+    ptrLoc =
+        (ptrLoc ? builder.create(loc, ptrLoc, update).getResult()
+                : update);
+  }
+  SmallVector linearizedIndices;
+  // Add a '0' at the start to index into the struct.
+  linearizedIndices.push_back(builder.create(
+      loc, indexType, IntegerAttr::get(indexType, 0)));
+  linearizedIndices.push_back(ptrLoc);
+  return builder.create(loc, basePtr, linearizedIndices);
+}
+
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
@@ -38,6 +80,7 @@ namespace {
 /// operation. Since IndexType is not used within SPIR-V dialect, this needs
 /// special handling to make sure the result type and the type of the value
 /// attribute are consistent.
+// TODO(ravishankarm) : This should be moved into DRR.
 class ConstantIndexOpConversion final : public SPIRVOpLowering {
 public:
   using SPIRVOpLowering::SPIRVOpLowering;
@@ -112,6 +155,7 @@ public:
 /// the type of the return value of the replacement operation differs from
 /// that of the replaced operation. This is not handled in tablegen-based
 /// pattern specification.
+// TODO(ravishankarm) : This should be moved into DRR.
 template 
 class IntegerOpConversion final : public SPIRVOpLowering {
 public:
@@ -128,36 +172,10 @@ public:
   }
 };
 
-// If 'basePtr' is the result of lowering a value of MemRefType, and 'indices'
-// are the indices used to index into the original value (for load/store),
-// perform the equivalent address calculation in SPIR-V.
-spirv::AccessChainOp getElementPtr(OpBuilder &builder, Location loc,
-                                   Value *basePtr, ArrayRef indices,
-                                   SPIRVTypeConverter &typeConverter) {
-  // MemRefType is converted to a
-  // spirv::StructType>>
-  auto ptrType = basePtr->getType().cast();
-  (void)ptrType;
-  auto structType = ptrType.getPointeeType().cast();
-  (void)structType;
-  assert(structType.getNumElements() == 1);
-  auto indexType = typeConverter.getIndexType(builder.getContext());
-
-  // Need to add a '0' at the beginning of the index list for accessing into the
-  // struct that wraps the nested array types.
-  Value *zero = spirv::ConstantOp::getZero(indexType, loc, &builder);
-  SmallVector accessIndices;
-  accessIndices.reserve(1 + indices.size());
-  accessIndices.push_back(zero);
-  accessIndices.append(indices.begin(), indices.end());
-  return builder.create(loc, basePtr, accessIndices);
-}
-
 /// Convert load -> spv.LoadOp. The operands of the replaced operation are of
 /// IndexType while that of the replacement operation are of type i32. This is
 /// not supported in tablegen based pattern specification.
-// TODO(ravishankarm) : These could potentially be templated on the operation
-// being converted, since the same logic should work for linalg.load.
+// TODO(ravishankarm) : This should be moved into DRR.
 class LoadOpConversion final : public SPIRVOpLowering {
 public:
   using SPIRVOpLowering::SPIRVOpLowering;
@@ -166,9 +184,9 @@ public:
   matchAndRewrite(LoadOp loadOp, ArrayRef operands,
                   ConversionPatternRewriter &rewriter) const override {
     LoadOpOperandAdaptor loadOperands(operands);
-    auto basePtr = loadOperands.memref();
-    auto loadPtr = getElementPtr(rewriter, loadOp.getLoc(), basePtr,
-                                 loadOperands.indices(), typeConverter);
+    auto loadPtr = getElementPtr(rewriter, typeConverter, loadOp.getLoc(),
+                                 loadOp.memref()->getType().cast(),
+                                 loadOperands.memref(), loadOperands.indices());
     rewriter.replaceOpWithNewOp(loadOp, loadPtr,
                                                /*memory_access =*/nullptr,
                                                /*alignment =*/nullptr);
@@ -177,6 +195,7 @@ public:
 };
 
 /// Convert return -> spv.Return.
+// TODO(ravishankarm) : This should be moved into DRR.
 class ReturnToSPIRVConversion final : public SPIRVOpLowering {
 public:
   using SPIRVOpLowering::SPIRVOpLowering;
@@ -193,6 +212,7 @@ public:
 };
 
 /// Convert select -> spv.Select
+// TODO(ravishankarm) : This should be moved into DRR.
 class SelectOpConversion final : public SPIRVOpLowering {
 public:
   using SPIRVOpLowering::SPIRVOpLowering;
@@ -210,8 +230,7 @@ public:
 /// Convert store -> spv.StoreOp. The operands of the replaced operation are
 /// of IndexType while that of the replacement operation are of type i32. This
 /// is not supported in tablegen based pattern specification.
-// TODO(ravishankarm) : These could potentially be templated on the operation
-// being converted, since the same logic should work for linalg.store.
+// TODO(ravishankarm) : This should be moved into DRR.
 class StoreOpConversion final : public SPIRVOpLowering {
 public:
   using SPIRVOpLowering::SPIRVOpLowering;
@@ -220,11 +239,12 @@ public:
   matchAndRewrite(StoreOp storeOp, ArrayRef operands,
                   ConversionPatternRewriter &rewriter) const override {
     StoreOpOperandAdaptor storeOperands(operands);
-    auto value = storeOperands.value();
-    auto basePtr = storeOperands.memref();
-    auto storePtr = getElementPtr(rewriter, storeOp.getLoc(), basePtr,
-                                  storeOperands.indices(), typeConverter);
-    rewriter.replaceOpWithNewOp(storeOp, storePtr, value,
+    auto storePtr =
+        getElementPtr(rewriter, typeConverter, storeOp.getLoc(),
+                      storeOp.memref()->getType().cast(),
+                      storeOperands.memref(), storeOperands.indices());
+    rewriter.replaceOpWithNewOp(storeOp, storePtr,
+                                                storeOperands.value(),
                                                 /*memory_access =*/nullptr,
                                                 /*alignment =*/nullptr);
     return matchSuccess();
diff --git a/third_party/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/third_party/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
index baa9ed305aa..e3b550223e5 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -86,6 +86,33 @@ static Optional getTypeNumBytes(Type t) {
     return integerType.getWidth() / 8;
   } else if (auto floatType = t.dyn_cast()) {
     return floatType.getWidth() / 8;
+  } else if (auto memRefType = t.dyn_cast()) {
+    // TODO: Layout should also be controlled by the ABI attributes. For now
+    // using the layout from MemRef.
+    int64_t offset;
+    SmallVector strides;
+    if (!memRefType.hasStaticShape() ||
+        failed(getStridesAndOffset(memRefType, strides, offset))) {
+      return llvm::None;
+    }
+    // To get the size of the memref object in memory, the total size is the
+    // max(stride * dimension-size) computed for all dimensions times the size
+    // of the element.
+    auto elementSize = getTypeNumBytes(memRefType.getElementType());
+    if (!elementSize) {
+      return llvm::None;
+    }
+    auto dims = memRefType.getShape();
+    if (llvm::is_contained(dims, ShapedType::kDynamicSize) ||
+        offset == MemRefType::getDynamicStrideOrOffset() ||
+        llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
+      return llvm::None;
+    }
+    int64_t memrefSize = -1;
+    for (auto shape : enumerate(dims)) {
+      memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
+    }
+    return (offset + memrefSize) * elementSize.getValue();
   }
   // TODO: Add size computation for other types.
   return llvm::None;
@@ -120,40 +147,21 @@ static Type convertStdType(Type type) {
     if (!elementSize) {
       return Type();
     }
-
-    if (!memRefType.hasStaticShape()) {
-      // TODO(ravishankarm) : Handle dynamic shapes.
-      return Type();
+    // TODO(ravishankarm) : Handle dynamic shapes.
+    if (memRefType.hasStaticShape()) {
+      auto arraySize = getTypeNumBytes(memRefType);
+      if (!arraySize) {
+        return Type();
+      }
+      auto arrayType = spirv::ArrayType::get(
+          elementType, arraySize.getValue() / elementSize.getValue(),
+          elementSize.getValue());
+      auto structType = spirv::StructType::get(arrayType, 0);
+      // For now initialize the storage class to StorageBuffer. This will be
+      // updated later based on whats passed in w.r.t to the ABI attributes.
+      return spirv::PointerType::get(structType,
+                                     spirv::StorageClass::StorageBuffer);
     }
-
-    // Get the strides and offset.
-    int64_t offset;
-    SmallVector strides;
-    if (failed(getStridesAndOffset(memRefType, strides, offset)) ||
-        offset == MemRefType::getDynamicStrideOrOffset() ||
-        llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
-      // TODO(ravishankarm) : Handle dynamic strides and offsets.
-      return Type();
-    }
-
-    // Convert to a multi-dimensional spv.array if size is known.
-    auto shape = memRefType.getShape();
-    assert(shape.size() == strides.size());
-    Type arrayType = elementType;
-    // TODO(antiagainst): Introduce layout as part of the shader ABI to have
-    // better separate of concerns.
-    for (int i = shape.size(); i > 0; --i) {
-      arrayType = spirv::ArrayType::get(
-          arrayType, shape[i - 1], strides[i - 1] * elementSize.getValue());
-    }
-
-    // For the offset, need to wrap the array in a struct.
-    auto structType =
-        spirv::StructType::get(arrayType, offset * elementSize.getValue());
-    // For now initialize the storage class to StorageBuffer. This will be
-    // updated later based on whats passed in w.r.t to the ABI attributes.
-    return spirv::PointerType::get(structType,
-                                   spirv::StorageClass::StorageBuffer);
   }
 
   return Type();

From 52d84416cbe8e30a58b0395985a7341fb8f96b1a Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Tue, 3 Dec 2019 10:25:30 -0800
Subject: [PATCH 222/279] Use separate allocator for cached prepacked matrix
 allocations.

This CL splits out the SystemAlignedAlloc/Free functions so that they are independently usable in this way.

ruy::Allocator is a highly specialized allocator designed for the hot path of multiple gemm's.

The use case of cached pre-packing has a very different set of tradeoffs.

PiperOrigin-RevId: 283572179
Change-Id: I510afc80fce6b591e63dc120869ab42ff7cd605f
---
 tensorflow/lite/experimental/ruy/allocator.cc |  6 +-
 tensorflow/lite/experimental/ruy/allocator.h  | 75 ++++++++++---------
 .../lite/experimental/ruy/prepacked_cache.cc  | 13 +++-
 .../lite/experimental/ruy/prepacked_cache.h   | 41 ++--------
 4 files changed, 57 insertions(+), 78 deletions(-)

diff --git a/tensorflow/lite/experimental/ruy/allocator.cc b/tensorflow/lite/experimental/ruy/allocator.cc
index 60a905136fd..8c4536bdeb1 100644
--- a/tensorflow/lite/experimental/ruy/allocator.cc
+++ b/tensorflow/lite/experimental/ruy/allocator.cc
@@ -26,19 +26,19 @@ namespace ruy {
 
 namespace detail {
 
-void *SystemAlignedAlloc(std::ptrdiff_t num_bytes) {
+void *AlignedAllocator::SystemAlignedAlloc(std::ptrdiff_t num_bytes) {
 #ifdef _WIN32
   return _aligned_malloc(num_bytes, kAlignment);
 #else
   void *ptr;
-  if (posix_memalign(&ptr, kMinimumBlockAlignment, num_bytes)) {
+  if (posix_memalign(&ptr, kAlignment, num_bytes)) {
     return nullptr;
   }
   return ptr;
 #endif
 }
 
-void SystemAlignedFree(void *ptr) {
+void AlignedAllocator::SystemAlignedFree(void *ptr) {
 #ifdef _WIN32
   _aligned_free(ptr);
 #else
diff --git a/tensorflow/lite/experimental/ruy/allocator.h b/tensorflow/lite/experimental/ruy/allocator.h
index 2f5c98d6870..f233090ce49 100644
--- a/tensorflow/lite/experimental/ruy/allocator.h
+++ b/tensorflow/lite/experimental/ruy/allocator.h
@@ -34,49 +34,38 @@ inline void* VoidPtrAdd(void* p, std::ptrdiff_t offset) {
   return reinterpret_cast(addr);
 }
 
-// Minimum alignment for blocks.
-//
-// Considerations:
-//  - This needs to be at least the alignment of any usual data type.
-//  - It's useful that this is at least the size of a cache line to limit
-//    possible cache side effects (if only on performance behavior).
-//  - It's useful that this is at least the size of SIMD registers, as
-//    some SIMD instruction sets have at least performance behavior
-//    differences (e.g. NEON) or even different requirements (e.g. SSE)
-//    based on that.
-//  - It's useful that this is at least the size of an "exclusive reservation
-//    granule" on ARM, meaning that if we use this Allocator to allocate
-//    an atomic variable, there will be no side effects from other things
-//    contending for exclusive/atomic memory accesses to it. While the
-//    ARM reference manual mentions that this granule size may be as large
-//    as 2048 bytes, in practice we observe it to be 64 bytes. It can
-//    be queried cheaply, at runtime, from userspace, if needed.
-static constexpr std::ptrdiff_t kMinimumBlockAlignment = 64;
-
-// Primitive allocation functions obtaining aligned memory from the
-// operating system.
-void* SystemAlignedAlloc(std::ptrdiff_t num_bytes);
-void SystemAlignedFree(void* ptr);
-
-// Specialized allocator designed to converge to a steady-state where all
+// Simple allocator designed to converge to a steady-state where all
 // allocations are bump-ptr allocations from an already-allocated buffer.
 //
 // To support these constraints, this allocator only supports two
 // operations.
 // - AllocateAlignedBytes: allocates a pointer to storage of a specified
-// size, which must be aligned to kMinimumBlockAlignment.
+// size, which must be aligned to kAlignment.
 // - FreeAll: frees all previous allocations (but retains the internal
 // buffer to minimize future calls into the system allocator).
 //
-// This class is specialized for supporting just those two operations
-// under this specific steady-state usage pattern. Extending this class
-// with new allocation interfaces that don't fit that pattern is probably not
-// the right choice. Instead, build a new class on top of
-// SystemAlignedAlloc/SystemAlignedFree.
-//
 // All operations happen on aligned blocks for simplicity.
 class AlignedAllocator {
  public:
+  // Alignment of allocated blocks.
+  //
+  // Considerations:
+  //  - This needs to be at least the alignment of any usual data type.
+  //  - It's useful that this is at least the size of a cache line to limit
+  //    possible cache side effects (if only on performance behavior).
+  //  - It's useful that this is at least the size of SIMD registers, as
+  //    some SIMD instruction sets have at least performance behavior
+  //    differences (e.g. NEON) or even different requirements (e.g. SSE)
+  //    based on that.
+  //  - It's useful that this is at least the size of an "exclusive reservation
+  //    granule" on ARM, meaning that if we use this Allocator to allocate
+  //    an atomic variable, there will be no side effects from other things
+  //    contending for exclusive/atomic memory accesses to it. While the
+  //    ARM reference manual mentions that this granule size may be as large
+  //    as 2048 bytes, in practice we observe it to be 64 bytes. It can
+  //    be queried cheaply, at runtime, from userspace, if needed.
+  static constexpr std::ptrdiff_t kAlignment = 64;
+
   void operator=(const AlignedAllocator&) = delete;
   ~AlignedAllocator() {
     FreeAll();
@@ -85,7 +74,7 @@ class AlignedAllocator {
 
   void* AllocateAlignedBytes(std::ptrdiff_t num_bytes) {
     RUY_DCHECK_GT(num_bytes, 0);
-    RUY_DCHECK((num_bytes & (kMinimumBlockAlignment - 1)) == 0);
+    RUY_DCHECK((num_bytes & (kAlignment - 1)) == 0);
     if (void* p = AllocateFast(num_bytes)) {
       return p;
     }
@@ -116,7 +105,17 @@ class AlignedAllocator {
     fallback_blocks_total_size_ = 0;
   }
 
- private:
+  void FreeOne(void* ptr) {
+    for (auto p = fallback_blocks_.begin(); p != fallback_blocks_.end(); ++p) {
+      if (*p == ptr) {
+        SystemAlignedFree(ptr);
+        fallback_blocks_.erase(p);
+        return;
+      }
+    }
+    RUY_DCHECK(false);  // Trying to free pointer we did not allocate.
+  }
+
   void* AllocateFast(std::ptrdiff_t num_bytes) {
     if (current_ + num_bytes > size_) {
       return nullptr;
@@ -133,6 +132,12 @@ class AlignedAllocator {
     return p;
   }
 
+ private:
+  // Primitive allocation functions obtaining aligned memory from the
+  // operating system.
+  void* SystemAlignedAlloc(std::ptrdiff_t num_bytes);
+  void SystemAlignedFree(void* ptr);
+
   // Theory of operation:
   //
   // - ptr_, current_, and size_ implement a basic bump-ptr allocator.
@@ -166,7 +171,7 @@ class Allocator {
       return nullptr;
     }
     return aligned.AllocateAlignedBytes(
-        round_up_pot(num_bytes, detail::kMinimumBlockAlignment));
+        round_up_pot(num_bytes, detail::AlignedAllocator::kAlignment));
   }
   template 
   void Allocate(std::ptrdiff_t count, Pointer* out) {
diff --git a/tensorflow/lite/experimental/ruy/prepacked_cache.cc b/tensorflow/lite/experimental/ruy/prepacked_cache.cc
index 2bd23f834c4..93fc4363044 100644
--- a/tensorflow/lite/experimental/ruy/prepacked_cache.cc
+++ b/tensorflow/lite/experimental/ruy/prepacked_cache.cc
@@ -58,14 +58,19 @@ void PrepackedCache::EjectOne() {
   PrepackedMatrix &pmatrix = oldest->second.first;
   cache_size_ -= pmatrix.data_size;
   cache_size_ -= pmatrix.sums_size;
-  allocator_.Free(pmatrix.data);
-  allocator_.Free(pmatrix.sums);
+  allocator_.FreeOne(pmatrix.data);
+  allocator_.FreeOne(pmatrix.sums);
   cache_.erase(oldest);
 }
 
 void PrepackedCache::AllocatePrepackedMatrix(PrepackedMatrix *pmatrix) {
-  pmatrix->data = allocator_.Alloc(pmatrix->data_size);
-  pmatrix->sums = allocator_.Alloc(pmatrix->sums_size);
+  pmatrix->data = AllocateBytes(pmatrix->data_size);
+  pmatrix->sums = AllocateBytes(pmatrix->sums_size);
+}
+
+void *PrepackedCache::AllocateBytes(std::ptrdiff_t num_bytes) {
+  // Force system allocation for now to enable easy ejections.
+  return allocator_.AllocateSlow(num_bytes);
 }
 
 void PrepackedCache::DoInsert(const CacheKey &key,
diff --git a/tensorflow/lite/experimental/ruy/prepacked_cache.h b/tensorflow/lite/experimental/ruy/prepacked_cache.h
index 9c77c48cf69..053108e61ed 100644
--- a/tensorflow/lite/experimental/ruy/prepacked_cache.h
+++ b/tensorflow/lite/experimental/ruy/prepacked_cache.h
@@ -16,7 +16,6 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PREPACKED_CACHE_H_
 #define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PREPACKED_CACHE_H_
 
-#include 
 #include 
 #include 
 #include 
@@ -28,40 +27,6 @@ limitations under the License.
 
 namespace ruy {
 
-namespace detail {
-
-// Tracks a set of blocks allocated from the underlying system allocator.
-class SystemBlockAllocator {
- public:
-  void *Alloc(std::ptrdiff_t num_bytes) {
-    void *p = detail::SystemAlignedAlloc(num_bytes);
-    blocks_.push_back(p);
-    return p;
-  }
-
-  void Free(void *block) {
-    for (auto it = blocks_.begin(); it != blocks_.end(); ++it) {
-      if (*it == block) {
-        detail::SystemAlignedFree(block);
-        blocks_.erase(it);
-        return;
-      }
-    }
-    RUY_DCHECK(false);  // Trying to free pointer we did not allocate.
-  }
-
-  ~SystemBlockAllocator() {
-    for (void *block : blocks_) {
-      detail::SystemAlignedFree(block);
-    }
-  }
-
- private:
-  std::vector blocks_;
-};
-
-}  // namespace detail
-
 enum CachePolicy { kNoCache, kCacheLHSOnGemV };
 
 // "Low effort" Least Recently Used Cache for Prepacked Matrices
@@ -115,8 +80,12 @@ class PrepackedCache {
 
  private:
   void EjectOne();
+  void *AllocateBytes(std::ptrdiff_t num_bytes);
   void DoInsert(const CacheKey &key, const PrepackedMatrix &matrix);
-  detail::SystemBlockAllocator allocator_;
+  // Since this cache is used in the context of "pre-packing", we need to
+  // handle allocating the space for the packed matrix ourselves, so we need
+  // our own allocator.
+  AlignedAllocator allocator_;
   std::map cache_;
   const int32_t ejection_threshold_;
   size_t cache_size_;

From 9cf5479f0e036fdc77f8d890696d1475726e3b23 Mon Sep 17 00:00:00 2001
From: Sean Silva 
Date: Tue, 3 Dec 2019 10:42:15 -0800
Subject: [PATCH 223/279] Resubmit of http://cl/283555950 with fix for win32.

PiperOrigin-RevId: 283575731
Change-Id: I88bae7526cbd795ead7ce051de3ba4a1d865d6ab
---
 tensorflow/lite/experimental/ruy/allocator.cc |  8 +-
 tensorflow/lite/experimental/ruy/allocator.h  | 75 +++++++++----------
 .../lite/experimental/ruy/prepacked_cache.cc  | 13 +---
 .../lite/experimental/ruy/prepacked_cache.h   | 41 ++++++++--
 4 files changed, 79 insertions(+), 58 deletions(-)

diff --git a/tensorflow/lite/experimental/ruy/allocator.cc b/tensorflow/lite/experimental/ruy/allocator.cc
index 8c4536bdeb1..d702f70e9fb 100644
--- a/tensorflow/lite/experimental/ruy/allocator.cc
+++ b/tensorflow/lite/experimental/ruy/allocator.cc
@@ -26,19 +26,19 @@ namespace ruy {
 
 namespace detail {
 
-void *AlignedAllocator::SystemAlignedAlloc(std::ptrdiff_t num_bytes) {
+void *SystemAlignedAlloc(std::ptrdiff_t num_bytes) {
 #ifdef _WIN32
-  return _aligned_malloc(num_bytes, kAlignment);
+  return _aligned_malloc(num_bytes, kMinimumBlockAlignment);
 #else
   void *ptr;
-  if (posix_memalign(&ptr, kAlignment, num_bytes)) {
+  if (posix_memalign(&ptr, kMinimumBlockAlignment, num_bytes)) {
     return nullptr;
   }
   return ptr;
 #endif
 }
 
-void AlignedAllocator::SystemAlignedFree(void *ptr) {
+void SystemAlignedFree(void *ptr) {
 #ifdef _WIN32
   _aligned_free(ptr);
 #else
diff --git a/tensorflow/lite/experimental/ruy/allocator.h b/tensorflow/lite/experimental/ruy/allocator.h
index f233090ce49..2f5c98d6870 100644
--- a/tensorflow/lite/experimental/ruy/allocator.h
+++ b/tensorflow/lite/experimental/ruy/allocator.h
@@ -34,38 +34,49 @@ inline void* VoidPtrAdd(void* p, std::ptrdiff_t offset) {
   return reinterpret_cast(addr);
 }
 
-// Simple allocator designed to converge to a steady-state where all
+// Minimum alignment for blocks.
+//
+// Considerations:
+//  - This needs to be at least the alignment of any usual data type.
+//  - It's useful that this is at least the size of a cache line to limit
+//    possible cache side effects (if only on performance behavior).
+//  - It's useful that this is at least the size of SIMD registers, as
+//    some SIMD instruction sets have at least performance behavior
+//    differences (e.g. NEON) or even different requirements (e.g. SSE)
+//    based on that.
+//  - It's useful that this is at least the size of an "exclusive reservation
+//    granule" on ARM, meaning that if we use this Allocator to allocate
+//    an atomic variable, there will be no side effects from other things
+//    contending for exclusive/atomic memory accesses to it. While the
+//    ARM reference manual mentions that this granule size may be as large
+//    as 2048 bytes, in practice we observe it to be 64 bytes. It can
+//    be queried cheaply, at runtime, from userspace, if needed.
+static constexpr std::ptrdiff_t kMinimumBlockAlignment = 64;
+
+// Primitive allocation functions obtaining aligned memory from the
+// operating system.
+void* SystemAlignedAlloc(std::ptrdiff_t num_bytes);
+void SystemAlignedFree(void* ptr);
+
+// Specialized allocator designed to converge to a steady-state where all
 // allocations are bump-ptr allocations from an already-allocated buffer.
 //
 // To support these constraints, this allocator only supports two
 // operations.
 // - AllocateAlignedBytes: allocates a pointer to storage of a specified
-// size, which must be aligned to kAlignment.
+// size, which must be aligned to kMinimumBlockAlignment.
 // - FreeAll: frees all previous allocations (but retains the internal
 // buffer to minimize future calls into the system allocator).
 //
+// This class is specialized for supporting just those two operations
+// under this specific steady-state usage pattern. Extending this class
+// with new allocation interfaces that don't fit that pattern is probably not
+// the right choice. Instead, build a new class on top of
+// SystemAlignedAlloc/SystemAlignedFree.
+//
 // All operations happen on aligned blocks for simplicity.
 class AlignedAllocator {
  public:
-  // Alignment of allocated blocks.
-  //
-  // Considerations:
-  //  - This needs to be at least the alignment of any usual data type.
-  //  - It's useful that this is at least the size of a cache line to limit
-  //    possible cache side effects (if only on performance behavior).
-  //  - It's useful that this is at least the size of SIMD registers, as
-  //    some SIMD instruction sets have at least performance behavior
-  //    differences (e.g. NEON) or even different requirements (e.g. SSE)
-  //    based on that.
-  //  - It's useful that this is at least the size of an "exclusive reservation
-  //    granule" on ARM, meaning that if we use this Allocator to allocate
-  //    an atomic variable, there will be no side effects from other things
-  //    contending for exclusive/atomic memory accesses to it. While the
-  //    ARM reference manual mentions that this granule size may be as large
-  //    as 2048 bytes, in practice we observe it to be 64 bytes. It can
-  //    be queried cheaply, at runtime, from userspace, if needed.
-  static constexpr std::ptrdiff_t kAlignment = 64;
-
   void operator=(const AlignedAllocator&) = delete;
   ~AlignedAllocator() {
     FreeAll();
@@ -74,7 +85,7 @@ class AlignedAllocator {
 
   void* AllocateAlignedBytes(std::ptrdiff_t num_bytes) {
     RUY_DCHECK_GT(num_bytes, 0);
-    RUY_DCHECK((num_bytes & (kAlignment - 1)) == 0);
+    RUY_DCHECK((num_bytes & (kMinimumBlockAlignment - 1)) == 0);
     if (void* p = AllocateFast(num_bytes)) {
       return p;
     }
@@ -105,17 +116,7 @@ class AlignedAllocator {
     fallback_blocks_total_size_ = 0;
   }
 
-  void FreeOne(void* ptr) {
-    for (auto p = fallback_blocks_.begin(); p != fallback_blocks_.end(); ++p) {
-      if (*p == ptr) {
-        SystemAlignedFree(ptr);
-        fallback_blocks_.erase(p);
-        return;
-      }
-    }
-    RUY_DCHECK(false);  // Trying to free pointer we did not allocate.
-  }
-
+ private:
   void* AllocateFast(std::ptrdiff_t num_bytes) {
     if (current_ + num_bytes > size_) {
       return nullptr;
@@ -132,12 +133,6 @@ class AlignedAllocator {
     return p;
   }
 
- private:
-  // Primitive allocation functions obtaining aligned memory from the
-  // operating system.
-  void* SystemAlignedAlloc(std::ptrdiff_t num_bytes);
-  void SystemAlignedFree(void* ptr);
-
   // Theory of operation:
   //
   // - ptr_, current_, and size_ implement a basic bump-ptr allocator.
@@ -171,7 +166,7 @@ class Allocator {
       return nullptr;
     }
     return aligned.AllocateAlignedBytes(
-        round_up_pot(num_bytes, detail::AlignedAllocator::kAlignment));
+        round_up_pot(num_bytes, detail::kMinimumBlockAlignment));
   }
   template 
   void Allocate(std::ptrdiff_t count, Pointer* out) {
diff --git a/tensorflow/lite/experimental/ruy/prepacked_cache.cc b/tensorflow/lite/experimental/ruy/prepacked_cache.cc
index 93fc4363044..2bd23f834c4 100644
--- a/tensorflow/lite/experimental/ruy/prepacked_cache.cc
+++ b/tensorflow/lite/experimental/ruy/prepacked_cache.cc
@@ -58,19 +58,14 @@ void PrepackedCache::EjectOne() {
   PrepackedMatrix &pmatrix = oldest->second.first;
   cache_size_ -= pmatrix.data_size;
   cache_size_ -= pmatrix.sums_size;
-  allocator_.FreeOne(pmatrix.data);
-  allocator_.FreeOne(pmatrix.sums);
+  allocator_.Free(pmatrix.data);
+  allocator_.Free(pmatrix.sums);
   cache_.erase(oldest);
 }
 
 void PrepackedCache::AllocatePrepackedMatrix(PrepackedMatrix *pmatrix) {
-  pmatrix->data = AllocateBytes(pmatrix->data_size);
-  pmatrix->sums = AllocateBytes(pmatrix->sums_size);
-}
-
-void *PrepackedCache::AllocateBytes(std::ptrdiff_t num_bytes) {
-  // Force system allocation for now to enable easy ejections.
-  return allocator_.AllocateSlow(num_bytes);
+  pmatrix->data = allocator_.Alloc(pmatrix->data_size);
+  pmatrix->sums = allocator_.Alloc(pmatrix->sums_size);
 }
 
 void PrepackedCache::DoInsert(const CacheKey &key,
diff --git a/tensorflow/lite/experimental/ruy/prepacked_cache.h b/tensorflow/lite/experimental/ruy/prepacked_cache.h
index 053108e61ed..9c77c48cf69 100644
--- a/tensorflow/lite/experimental/ruy/prepacked_cache.h
+++ b/tensorflow/lite/experimental/ruy/prepacked_cache.h
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PREPACKED_CACHE_H_
 #define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PREPACKED_CACHE_H_
 
+#include 
 #include 
 #include 
 #include 
@@ -27,6 +28,40 @@ limitations under the License.
 
 namespace ruy {
 
+namespace detail {
+
+// Tracks a set of blocks allocated from the underlying system allocator.
+class SystemBlockAllocator {
+ public:
+  void *Alloc(std::ptrdiff_t num_bytes) {
+    void *p = detail::SystemAlignedAlloc(num_bytes);
+    blocks_.push_back(p);
+    return p;
+  }
+
+  void Free(void *block) {
+    for (auto it = blocks_.begin(); it != blocks_.end(); ++it) {
+      if (*it == block) {
+        detail::SystemAlignedFree(block);
+        blocks_.erase(it);
+        return;
+      }
+    }
+    RUY_DCHECK(false);  // Trying to free pointer we did not allocate.
+  }
+
+  ~SystemBlockAllocator() {
+    for (void *block : blocks_) {
+      detail::SystemAlignedFree(block);
+    }
+  }
+
+ private:
+  std::vector blocks_;
+};
+
+}  // namespace detail
+
 enum CachePolicy { kNoCache, kCacheLHSOnGemV };
 
 // "Low effort" Least Recently Used Cache for Prepacked Matrices
@@ -80,12 +115,8 @@ class PrepackedCache {
 
  private:
   void EjectOne();
-  void *AllocateBytes(std::ptrdiff_t num_bytes);
   void DoInsert(const CacheKey &key, const PrepackedMatrix &matrix);
-  // Since this cache is used in the context of "pre-packing", we need to
-  // handle allocating the space for the packed matrix ourselves, so we need
-  // our own allocator.
-  AlignedAllocator allocator_;
+  detail::SystemBlockAllocator allocator_;
   std::map cache_;
   const int32_t ejection_threshold_;
   size_t cache_size_;

From 094da7eaa9d802ee676170f2df58ded528da52d9 Mon Sep 17 00:00:00 2001
From: Bixia Zheng 
Date: Tue, 3 Dec 2019 10:43:27 -0800
Subject: [PATCH 224/279] [TF:MLIR] Add inliner and shape inference to the
 standard pipeline.

The inliner pass is added to the standard pipeline only when the enable_inliner
option is on.

Add enable_inliner knob to RunBridgeWithStandardPipeline.

PiperOrigin-RevId: 283575979
Change-Id: I977a073a96b3a95950b53b8474090b0a1eec594f
---
 .../compiler/mlir/tensorflow/transforms/bridge.cc   |  7 +++++--
 .../compiler/mlir/tensorflow/transforms/bridge.h    |  8 +++++---
 .../mlir/tensorflow/transforms/bridge_pass.cc       |  6 ------
 .../compiler/mlir/tensorflow/transforms/optimize.cc | 13 +++++++++----
 .../compiler/mlir/tensorflow/transforms/passes.h    | 12 +++++++++++-
 tensorflow/compiler/tf2xla/mlir_tf2xla.cc           |  2 +-
 6 files changed, 31 insertions(+), 17 deletions(-)

diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc
index fccedf4057a..d9bae902382 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc
@@ -68,14 +68,17 @@ tensorflow::Status TPUBridge(ModuleOp module, bool enable_logging) {
 namespace TF {
 
 tensorflow::Status RunBridgeWithStandardPipeline(ModuleOp module,
-                                                 bool enable_logging) {
+                                                 bool enable_logging,
+                                                 bool enable_inliner) {
   PassManager bridge(module.getContext());
 
   // Add logger to bridge passmanager.
   if (enable_logging)
     bridge.addInstrumentation(std::make_unique());
 
-  CreateTFStandardPipeline(bridge);
+  StandardPipeline::Options pipeline_options;
+  pipeline_options.enable_inliner.setValue(enable_inliner);
+  CreateTFStandardPipeline(bridge, pipeline_options);
   mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
   LogicalResult result = bridge.run(module);
   (void)result;
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h
index 844b9095dba..ff446af24f5 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h
@@ -31,11 +31,13 @@ tensorflow::Status TPUBridge(ModuleOp module, bool enable_logging);
 
 namespace TF {
 
-// Run all passes involved in transforming or optimizing an MLIR graph without
+// Runs all passes involved in transforming or optimizing an MLIR graph without
 // any target specialization. When enable_logging is true, enables
-// tensorflow::BridgeLogger.
+// tensorflow::BridgeLogger. When enable_inliner is true, enables the inliner
+// pass.
 tensorflow::Status RunBridgeWithStandardPipeline(ModuleOp module,
-                                                 bool enable_logging);
+                                                 bool enable_logging,
+                                                 bool enable_inliner);
 
 }  // namespace TF
 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge_pass.cc
index b19bb0f8cd5..0208dc2f579 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge_pass.cc
@@ -29,10 +29,4 @@ mlir::PassPipelineRegistration<> tpu_pipeline(
     "that it is suitable for targeting TPUs.",
     mlir::TFTPU::CreateTPUBridge);
 
-mlir::PassPipelineRegistration<> standard_pipeline(
-    "tf-standard-bridge",
-    "Run all passes involved in transforming or optimizing an MLIR graph"
-    "without any target specialization.",
-    mlir::TF::CreateTFStandardPipeline);
-
 }  // anonymous namespace
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc
index c6e3a0ab895..9dd5accc81c 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc
@@ -45,7 +45,8 @@ struct TFOptimizePass : public FunctionPass {
 }  // namespace
 
 // NOLINTNEXTLINE - MLIR contract is pass by mutable reference.
-void CreateTFStandardPipeline(OpPassManager &pm) {
+void CreateTFStandardPipeline(OpPassManager &pm,
+                              const StandardPipeline::Options &options) {
   OpPassManager &func_pm = pm.nest();
 
   // First operates on the executor dialect:
@@ -59,8 +60,12 @@ void CreateTFStandardPipeline(OpPassManager &pm) {
   // Hopefully there is a single island left, or there wasn't any to begin with.
   // We now run the optimizer which operates mostly inside islands.
   func_pm.addPass(createCanonicalizerPass());
-  func_pm.addPass(CreateTFOptimizePass());
-  func_pm.addPass(createCSEPass());
+  if (options.enable_inliner) {
+    pm.addPass(createInlinerPass());
+  }
+  pm.addNestedPass(CreateTFShapeInferencePass());
+  pm.addNestedPass(CreateTFOptimizePass());
+  pm.addNestedPass(createCSEPass());
 }
 
 std::unique_ptr> CreateTFOptimizePass() {
@@ -70,7 +75,7 @@ std::unique_ptr> CreateTFOptimizePass() {
 static PassRegistration pass("tf-optimize", "Optimizes TF.");
 
 // Registers a pipeline builder function for the default canonicalize/optimizer.
-static mlir::PassPipelineRegistration<> pipeline(
+static mlir::PassPipelineRegistration pipeline(
     "tf-standard-pipeline",
     "Run all the passes involved in transforming/optimizing the graph after "
     "importing into MLIR, without any target specialization.",
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
index 7a5c060f5dc..49458c4eac6 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
@@ -46,10 +46,20 @@ std::unique_ptr> CreateTFShapeInferencePass();
 // Optimizes Tensorflow graph.
 std::unique_ptr> CreateTFOptimizePass();
 
+class StandardPipeline : public ModulePass {
+ public:
+  struct Options : public PassOptions {
+    Option enable_inliner{*this, "enable-inliner",
+                                llvm::cl::desc("Enable inliner."),
+                                llvm::cl::init(false)};
+  };
+};
+
 // Propagates the pass manager with the passes involved in transforming or
 // optimizing an MLIR graph without any target specialization.
 // NOLINTNEXTLINE - MLIR contract is pass by mutable reference.
-void CreateTFStandardPipeline(OpPassManager& pm);
+void CreateTFStandardPipeline(OpPassManager& pm,
+                              const StandardPipeline::Options& options);
 }  // namespace TF
 
 namespace TFControlFlow {
diff --git a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc
index 01325af3d39..ddfeb1a6b5a 100644
--- a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc
+++ b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc
@@ -109,7 +109,7 @@ Status ConvertGraphDefToXlaViaMlir(const GraphDef& graph_def,
   AddDevicesToOp(*module, &device_set);
 
   TF_RETURN_IF_ERROR(mlir::TF::RunBridgeWithStandardPipeline(
-      *module, /*enable_logging=*/VLOG_IS_ON(1)));
+      *module, /*enable_logging=*/VLOG_IS_ON(1), /*enable_inliner=*/true));
 
   // Convert the MLIR module to XLA computation. If the input graph can't be
   // lowered down to a single graph node with a single island by the previous

From 8ea0977ab8467684a9a7ef02d15ce48de9d4cc42 Mon Sep 17 00:00:00 2001
From: River Riddle 
Date: Tue, 3 Dec 2019 11:13:39 -0800
Subject: [PATCH 225/279] Allow analyses to provide a hook 'isInvalidated' to
 determine if they are truly invalidated.

The hook has the following form:
*   `bool isInvalidated(const AnalysisManager::PreservedAnalyses &)`

Given a preserved analysis set, the analysis returns true if it should truly be
invalidated. This allows for more fine-tuned invalidation in cases where an
analysis wasn't explicitly marked preserved, but may be preserved(or
invalidated) based upon other properties; such as analyses sets.

PiperOrigin-RevId: 283582889
Change-Id: Ice539bad590ae659d7815e28254173abfe0f2fa0
---
 third_party/mlir/g3doc/WritingAPass.md        | 16 +++++--
 .../mlir/include/mlir/Pass/AnalysisManager.h  | 45 ++++++++++++++++---
 2 files changed, 52 insertions(+), 9 deletions(-)

diff --git a/third_party/mlir/g3doc/WritingAPass.md b/third_party/mlir/g3doc/WritingAPass.md
index df0d153ad1a..1e4564aa21d 100644
--- a/third_party/mlir/g3doc/WritingAPass.md
+++ b/third_party/mlir/g3doc/WritingAPass.md
@@ -116,12 +116,20 @@ the following:
 *   Provide a valid constructor taking an `Operation*`.
 *   Must not modify the given operation.
 
-The base `OperationPass` class provide utilities for querying and preserving
-analyses for the current operation being processed. Using the example passes
-defined above, let's see some examples:
+An analysis may provide additional hooks to control various behavior:
+
+*   `bool isInvalidated(const AnalysisManager::PreservedAnalyses &)`
+
+Given a preserved analysis set, the analysis returns true if it should truly be
+invalidated. This allows for more fine-tuned invalidation in cases where an
+analysis wasn't explicitly marked preserved, but may be preserved(or
+invalidated) based upon other properties such as analyses sets.
 
 ### Querying Analyses
 
+The base `OperationPass` class provide utilities for querying and preserving
+analyses for the current operation being processed.
+
 *   OperationPass automatically provides the following utilities for querying
     analyses:
     *   `getAnalysis<>`
@@ -137,7 +145,7 @@ defined above, let's see some examples:
         -   Get an analysis for a given child operation, constructing it if
             necessary.
 
-A few example usages are shown below:
+Using the example passes defined above, let's see some examples:
 
 ```c++
 /// An interesting analysis.
diff --git a/third_party/mlir/include/mlir/Pass/AnalysisManager.h b/third_party/mlir/include/mlir/Pass/AnalysisManager.h
index 163ecf6356f..6c37223ad91 100644
--- a/third_party/mlir/include/mlir/Pass/AnalysisManager.h
+++ b/third_party/mlir/include/mlir/Pass/AnalysisManager.h
@@ -76,9 +76,36 @@ private:
   SmallPtrSet preservedIDs;
 };
 
+namespace analysis_impl {
+/// Trait to check if T provides a static 'isInvalidated' method.
+template 
+using has_is_invalidated = decltype(std::declval().isInvalidated(
+    std::declval()));
+
+/// Implementation of 'isInvalidated' if the analysis provides a definition.
+template 
+std::enable_if_t::value, bool>
+isInvalidated(AnalysisT &analysis, const PreservedAnalyses &pa) {
+  return analysis.isInvalidated(pa);
+}
+/// Default implementation of 'isInvalidated'.
+template 
+std::enable_if_t::value, bool>
+isInvalidated(AnalysisT &analysis, const PreservedAnalyses &pa) {
+  return !pa.isPreserved();
+}
+} // end namespace analysis_impl
+
 /// The abstract polymorphic base class representing an analysis.
 struct AnalysisConcept {
   virtual ~AnalysisConcept() = default;
+
+  /// A hook used to query analyses for invalidation. Given a preserved analysis
+  /// set, returns true if it should truly be invalidated. This allows for more
+  /// fine-tuned invalidation in cases where an analysis wasn't explicitly
+  /// marked preserved, but may be preserved(or invalidated) based upon other
+  /// properties such as analyses sets.
+  virtual bool isInvalidated(const PreservedAnalyses &pa) = 0;
 };
 
 /// A derived analysis model used to hold a specific analysis object.
@@ -87,6 +114,12 @@ template  struct AnalysisModel : public AnalysisConcept {
   explicit AnalysisModel(Args &&... args)
       : analysis(std::forward(args)...) {}
 
+  /// A hook used to query analyses for invalidation.
+  bool isInvalidated(const PreservedAnalyses &pa) final {
+    return analysis_impl::isInvalidated(analysis, pa);
+  }
+
+  /// The actual analysis object.
   AnalysisT analysis;
 };
 
@@ -147,11 +180,11 @@ public:
 
   /// Invalidate any cached analyses based upon the given set of preserved
   /// analyses.
-  void invalidate(const detail::PreservedAnalyses &pa) {
-    // Remove any analyses not marked as preserved.
+  void invalidate(const PreservedAnalyses &pa) {
+    // Remove any analyses that were invalidated.
     for (auto it = analyses.begin(), e = analyses.end(); it != e;) {
       auto curIt = it++;
-      if (!pa.isPreserved(curIt->first))
+      if (curIt->second->isInvalidated(pa))
         analyses.erase(curIt);
     }
   }
@@ -170,7 +203,7 @@ struct NestedAnalysisMap {
   Operation *getOperation() const { return analyses.getOperation(); }
 
   /// Invalidate any non preserved analyses.
-  void invalidate(const detail::PreservedAnalyses &pa);
+  void invalidate(const PreservedAnalyses &pa);
 
   /// The cached analyses for nested operations.
   llvm::DenseMap> childAnalyses;
@@ -195,6 +228,8 @@ class AnalysisManager {
                                             const AnalysisManager *>;
 
 public:
+  using PreservedAnalyses = detail::PreservedAnalyses;
+
   // Query for a cached analysis on the given parent operation. The analysis may
   // not exist and if it does it may be out-of-date.
   template 
@@ -240,7 +275,7 @@ public:
   AnalysisManager slice(Operation *op);
 
   /// Invalidate any non preserved analyses,
-  void invalidate(const detail::PreservedAnalyses &pa) { impl->invalidate(pa); }
+  void invalidate(const PreservedAnalyses &pa) { impl->invalidate(pa); }
 
   /// Clear any held analyses.
   void clear() {

From af999a2074b4176b45248f8bcb0e580a3c6e45c3 Mon Sep 17 00:00:00 2001
From: Sean Silva 
Date: Tue, 3 Dec 2019 11:23:48 -0800
Subject: [PATCH 226/279] Verifier: Better error message in case of successor
 operand mismatch.

In particular, print the successor number in the diagnostic.

PiperOrigin-RevId: 283585084
Change-Id: Id008bef94e07af0753fac6655c4ef91e88a793d7
---
 third_party/mlir/lib/IR/Operation.cpp | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/third_party/mlir/lib/IR/Operation.cpp b/third_party/mlir/lib/IR/Operation.cpp
index f0ebd59ab9f..d079033e39b 100644
--- a/third_party/mlir/lib/IR/Operation.cpp
+++ b/third_party/mlir/lib/IR/Operation.cpp
@@ -901,18 +901,21 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) {
   return success();
 }
 
-static LogicalResult verifyBBArguments(Operation::operand_range operands,
-                                       Block *destBB, Operation *op) {
-  unsigned operandCount = std::distance(operands.begin(), operands.end());
+static LogicalResult verifySuccessor(Operation *op, unsigned succNo) {
+  Operation::operand_range operands = op->getSuccessorOperands(succNo);
+  unsigned operandCount = op->getNumSuccessorOperands(succNo);
+  Block *destBB = op->getSuccessor(succNo);
   if (operandCount != destBB->getNumArguments())
     return op->emitError() << "branch has " << operandCount
-                           << " operands, but target block has "
+                           << " operands for successor #" << succNo
+                           << ", but target block has "
                            << destBB->getNumArguments();
 
   auto operandIt = operands.begin();
   for (unsigned i = 0, e = operandCount; i != e; ++i, ++operandIt) {
     if ((*operandIt)->getType() != destBB->getArgument(i)->getType())
-      return op->emitError() << "type mismatch in bb argument #" << i;
+      return op->emitError() << "type mismatch for bb argument #" << i
+                             << " of successor #" << succNo;
   }
 
   return success();
@@ -926,7 +929,7 @@ static LogicalResult verifyTerminatorSuccessors(Operation *op) {
     auto *succ = op->getSuccessor(i);
     if (succ->getParent() != parent)
       return op->emitError("reference to block defined in another region");
-    if (failed(verifyBBArguments(op->getSuccessorOperands(i), succ, op)))
+    if (failed(verifySuccessor(op, i)))
       return failure();
   }
   return success();

From ed4f9bd173be04a8b07112e3ea98ebbfada10fa9 Mon Sep 17 00:00:00 2001
From: Prakalp Srivastava 
Date: Tue, 3 Dec 2019 11:37:36 -0800
Subject: [PATCH 227/279] GetTupleElement export does not require custom
 support.

All operands and attributes of this op can be handled by the generic exporter.

PiperOrigin-RevId: 283588207
Change-Id: I31f8e188a3286866cf2df93a8b079797f6546099
---
 tensorflow/compiler/mlir/xla/ir/hlo_ops.td      | 3 ---
 tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc | 7 -------
 2 files changed, 10 deletions(-)

diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
index 4fb85f9f6b3..e285b172806 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
@@ -437,9 +437,6 @@ def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect]>, BASE_HLO
   let builders = [OpBuilder<
                   "Builder *builder, OperationState &results, "
                   "Value* value, int32_t index">];
-
-  // GetTupleElementOp has special conversion logic to HLO.
-  let hasCustomHLOConverter = 1;
 }
 
 def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp {
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
index f717c8199fd..93716331d0d 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
@@ -499,13 +499,6 @@ LogicalResult ExportXlaOp(GatherOp op, OpLoweringContext ctx) {
   return failure();
 }
 
-LogicalResult ExportXlaOp(GetTupleElementOp op, OpLoweringContext ctx) {
-  auto& value_map = *ctx.values;
-  value_map[op] = xla::GetTupleElement(value_map[op.getOperand()],
-                                       op.index().getSExtValue());
-  return success();
-}
-
 LogicalResult ExportXlaOp(IotaOp op, OpLoweringContext ctx) {
   auto& value_map = *ctx.values;
   value_map[op] = xla::Iota(ctx.builder, xla::TypeToShape(op.getType()),

From ea475b0b11d716c3077c269f4acdadaef4c82bfe Mon Sep 17 00:00:00 2001
From: George Karpenkov 
Date: Tue, 3 Dec 2019 11:39:16 -0800
Subject: [PATCH 228/279] [XLA/CPU] Fix race condition in all-reduce CPU
 implementation

We need to drop the reference to the rendezvous object, and *then* wait for all
other threads to do the same, doing it in the opposite order does not make
sense: otherwise one of the threads could run past this point, and attempt to
reuse the rendezvous.
PiperOrigin-RevId: 283588609
Change-Id: I32fa4e5fa326bdc0acdd3c00ea9f1270bf5973c7
---
 .../xla/service/collective_ops_utils.h        | 58 ++++++++++++++-----
 tensorflow/compiler/xla/service/cpu/BUILD     |  2 +
 .../compiler/xla/service/cpu/cpu_runtime.cc   | 27 +++------
 .../xla/service/gpu/nccl_all_reduce_thunk.cc  | 38 ++----------
 4 files changed, 58 insertions(+), 67 deletions(-)

diff --git a/tensorflow/compiler/xla/service/collective_ops_utils.h b/tensorflow/compiler/xla/service/collective_ops_utils.h
index f2fa9640f85..2c5f2d64d1f 100644
--- a/tensorflow/compiler/xla/service/collective_ops_utils.h
+++ b/tensorflow/compiler/xla/service/collective_ops_utils.h
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_OPS_UTILS_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_OPS_UTILS_H_
 
+#include 
 #include 
 
 #include "tensorflow/compiler/xla/executable_run_options.h"
@@ -188,6 +189,48 @@ class Rendezvous {
   virtual ~Rendezvous() {}
   explicit Rendezvous(const RendezvousKey& k) : key_(k) {}
 
+  // Submit a participant to the rendezvous. We get the rendezvous from
+  // `rendezvous_getter`, which we can then use to drop the existing reference.
+  static StatusOr SubmitParticipant(
+      std::function>()> rendezvous_getter,
+      AllReduceParticipantData participant) {
+    std::shared_ptr> rendezvous = rendezvous_getter();
+    TF_ASSIGN_OR_RETURN(auto p, rendezvous->SubmitParticipant(participant));
+
+    // Drop our reference to the Rendezvous and wait for all other threads to do
+    // the same.  If we didn't do this, one of the threads could run past this
+    // point, reenter ExecuteOnStream for another all-reduce, and attempt to
+    // reuse the Rendezvous!
+    //
+    // An alternative way of accomplishing this goal would be to implement
+    // RefcountingHashMap::erase() and call it during SubmitParticipant.  But
+    // erase() is deceptively complex to implement correctly.
+    std::shared_ptr blocking_counter = p.second;
+    rendezvous.reset();
+    blocking_counter->DecrementCount();
+    xla::WaitAndLogIfStuck(blocking_counter.get(), [&] {
+      return absl::StrFormat(
+          "participant waiting for all threads to drop their reference to the "
+          "rendezvous: %p",
+          rendezvous.get());
+    });
+    return p.first;
+  }
+
+ protected:
+  // Returns domain-specific output O and whether this replica is primary.
+  virtual StatusOr> SubmitParticipantImpl(
+      AllReduceParticipantData participant) = 0;
+
+  virtual void CleanupImpl(O handle, bool is_primary) {}
+
+  tensorflow::mutex mu_;
+
+  bool initialized_ GUARDED_BY(mu_) = false;
+
+  std::vector participants_ GUARDED_BY(mu_);
+
+ private:
   // Runs the all-reduce on the given thread.  If successful, returns
   //  - a handle to the clique that was used, so that the caller may keep the
   //    clique alive if it chooses.
@@ -248,21 +291,6 @@ class Rendezvous {
 
     return std::make_pair(handle, returned_blocking_counter_);
   }
-
- protected:
-  // Returns domain-specific output O and whether this replica is primary.
-  virtual StatusOr> SubmitParticipantImpl(
-      AllReduceParticipantData participant) = 0;
-
-  virtual void CleanupImpl(O handle, bool is_primary) {}
-
-  tensorflow::mutex mu_;
-
-  bool initialized_ GUARDED_BY(mu_) = false;
-
-  std::vector participants_ GUARDED_BY(mu_);
-
- private:
   const RendezvousKey key_;
 
   tensorflow::BlockingCounter all_participants_present_{
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index e3aa1551b8a..af856e92e70 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -489,6 +489,7 @@ cc_library(
         "//tensorflow/compiler/xla:executable_run_options",
         "//tensorflow/compiler/xla:refcounting_hash_map",
         "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
@@ -502,6 +503,7 @@ cc_library(
         "//tensorflow/core/platform:macros",
         "//tensorflow/core/platform:mutex",
         "//tensorflow/core/platform:platform_port",
+        "//tensorflow/core/platform:status",
         "//tensorflow/core/platform:types",
         "//tensorflow/core/profiler/lib:traceme",
         "//tensorflow/stream_executor",
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index 9b3e85427a3..56d663f7b24 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -37,6 +37,7 @@ limitations under the License.
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/mem.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
 #include "tensorflow/stream_executor/device_memory.h"
@@ -415,8 +416,6 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
   xla::RendezvousKey rendezvous_key(run_options->run_id(),
                                     participating_replicas_vec, op_kind, op_id);
 
-  std::shared_ptr rendezvous =
-      GlobalRendezvousMap()[rendezvous_key];
 
   auto shape_str = ShapeString(shape_ptr, shape_length);
   VLOG(2) << "All-reduce input/output shape : " << shape_str;
@@ -431,24 +430,16 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
   participant.device_ordinal = device_ordinal;
   participant.primitive_type = shape.element_type();
   participant.stream = run_options->stream();
-
-  se::DeviceMemoryBase input(input_buffer, xla::ShapeUtil::ByteSizeOf(shape));
-  se::DeviceMemoryBase output(output_buffer, xla::ShapeUtil::ByteSizeOf(shape));
-  participant.source_data = input;
-  participant.destination_data = output;
+  participant.source_data =
+      se::DeviceMemoryBase(input_buffer, xla::ShapeUtil::ByteSizeOf(shape));
+  participant.destination_data =
+      se::DeviceMemoryBase(output_buffer, xla::ShapeUtil::ByteSizeOf(shape));
   participant.reduction_kind = static_cast(reduction_kind);
 
-  auto p = rendezvous->SubmitParticipant(participant).ValueOrDie();
-  std::shared_ptr blocking_counter = p.second;
-  blocking_counter->DecrementCount();
-  xla::WaitAndLogIfStuck(blocking_counter.get(), [&] {
-    return absl::StrFormat(
-        "participant waiting for all threads to drop their reference to the "
-        "rendezvous: %s",
-        rendezvous_key.ToString());
-  });
-
-  rendezvous.reset();
+  TF_CHECK_OK(
+      CpuAllReduceRendezvous::SubmitParticipant(
+          [&] { return GlobalRendezvousMap()[rendezvous_key]; }, participant)
+          .status());
 }
 
 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ReplicaId(
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
index d74e7f22916..2fb1fc07056 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
@@ -499,11 +499,8 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
   // Find or create the rendezvous for this collective operation.
   RendezvousKey rendezvous_key = RendezvousKey::FromInstruction(
       params.run_id, participating_replicas, hlo_instruction());
-  std::shared_ptr rendezvous =
-      GlobalRendezvousMap()[rendezvous_key];
 
   VLOG(2) << "Rendezvous key: " << rendezvous_key.ToString()
-          << ", rendezvous: " << rendezvous.get()
           << ", participating replicas: "
           << absl::StrJoin(participating_replicas, ", ");
 
@@ -521,19 +518,10 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
   participant.reduction_kind = *reduction_kind;
   participant.primitive_type = AllReducePrimitiveType(hlo_instruction());
 
-  // Do the operation.
-  StatusOr,
-                     std::shared_ptr>>
-      result = rendezvous->SubmitParticipant(participant);
-  if (!result.ok()) {
-    VLOG(1) << "NcclAllReduceThunk::ExecuteOnStream failed: "
-            << result.status().ToString();
-    return result.status();
-  }
-
-  std::shared_ptr clique;
-  std::shared_ptr blocking_counter;
-  std::tie(clique, blocking_counter) = std::move(result).ValueOrDie();
+  TF_ASSIGN_OR_RETURN(
+      std::shared_ptr clique,
+      RendezvousNcclAllReduce::SubmitParticipant(
+          [&] { return GlobalRendezvousMap()[rendezvous_key]; }, participant));
 
   // Keep the clique we used alive for as long as this Thunk lives.  Creating
   // new NCCL cliques is expensive, and this is how we avoid thrashing them.
@@ -541,24 +529,6 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
     tensorflow::mutex_lock lock(aux_data_->mu);
     aux_data_->cliques.insert(std::move(clique));
   }
-
-  // Drop our reference to the Rendezvous and wait for all other threads to do
-  // the same.  If we didn't do this, one of the threads could run past this
-  // point, reenter ExecuteOnStream for another all-reduce, and attempt to reuse
-  // the Rendezvous!
-  //
-  // An alternative way of accomplishing this goal would be to implement
-  // RefcountingHashMap::erase() and call it during SubmitParticipant.  But
-  // erase() is deceptively complex to implement correctly.
-  rendezvous.reset();
-  blocking_counter->DecrementCount();
-  WaitAndLogIfStuck(blocking_counter.get(), [&] {
-    return absl::StrFormat(
-        "participant for device ordinal %d, stream %p waiting for "
-        "all threads to drop their reference to the rendezvous: %s",
-        device_ordinal, params.stream, rendezvous_key.ToString());
-  });
-
   return Status::OK();
 }
 

From a33db11a83a79c8529d836409c43d6b915d465bd Mon Sep 17 00:00:00 2001
From: Yu-Cheng Ling 
Date: Tue, 3 Dec 2019 11:46:24 -0800
Subject: [PATCH 229/279] tflite_convert: Always propagate the
 experimental_new_converter flag.

This is a non-functional change for now because the new converter is disblaed
by default. The change is for smoothening the flow when we enable it in the future.

PiperOrigin-RevId: 283590032
Change-Id: I1c599fc3b115dc882057560ff7a43fe52091715d
---
 tensorflow/lite/python/tflite_convert.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tensorflow/lite/python/tflite_convert.py b/tensorflow/lite/python/tflite_convert.py
index d66fe0bb5a9..02a00ea79b6 100644
--- a/tensorflow/lite/python/tflite_convert.py
+++ b/tensorflow/lite/python/tflite_convert.py
@@ -205,8 +205,9 @@ def _convert_tf1_model(flags):
   if flags.conversion_summary_dir:
     converter.conversion_summary_dir = flags.conversion_summary_dir
 
-  if flags.experimental_new_converter:
-    converter.experimental_new_converter = True
+  # TODO(b/145312675): Enable the new converter by default. It requires to
+  # add a new command line argument like `experimental_legacy_converter`.
+  converter.experimental_new_converter = flags.experimental_new_converter
 
   # Convert model.
   output_data = converter.convert()

From 67e6a8e6efb14363ad1d75f8369ba92b6dd68937 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Tue, 3 Dec 2019 11:55:09 -0800
Subject: [PATCH 230/279] Add CreateMaskOp to the VectorOps dialect.

PiperOrigin-RevId: 283591888
Change-Id: I38ca86649c59c704900755e46f440fe2a35bdfb3
---
 .../mlir/Dialect/VectorOps/VectorOps.td       | 32 ++++++++++++++++++-
 .../mlir/lib/Dialect/VectorOps/VectorOps.cpp  | 31 ++++++++++++++++++
 2 files changed, 62 insertions(+), 1 deletion(-)

diff --git a/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
index c75f9fe0231..d34fa9a245d 100644
--- a/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -627,7 +627,37 @@ def Vector_TypeCastOp :
   }];
 }
 
-// TODO(andydavis) Morph this operation into a Vector_MaskOp.
+// TODO(andydavis) Add constant folding support.
+def Vector_CreateMaskOp :
+  Vector_Op<"create_mask", [NoSideEffect]>,
+    Arguments<(ins Variadic:$operands)>, Results<(outs VectorOf<[I1]>)> {
+  let summary = "creates a vector mask";
+  let description = [{
+    Creates and returns a vector mask where elements of the result vector
+    are set to '0' or '1', based on whether the element indices are contained
+    within a hyper-rectangular region specified by the operands. Specifically,
+    each operand specifies a range [0, operand-value) for a unique dimension in
+    the vector result. The conjunction of the operand ranges define
+    hyper-rectangular region within which elements values are set to 1
+    (otherwise element values are set to 0).
+
+    Example: create a vector mask of size 4x3xi1 where elements in range
+             0 <= row <= 2 and 0 <= col <= 1 are set to 1 (others to 0).
+
+      %1 = vector.create_mask %c3, %c2 : vector<4x3xi1>
+
+      print %1
+                    columns
+                  0    1    2
+                |------------
+              0 | 1    1    0
+        rows  1 | 1    1    0
+              2 | 1    1    0
+              3 | 0    0    0
+  }];
+}
+
+// TODO(andydavis) Delete this op once ContractOp is converted to use VectorMask
 def Vector_IndexTupleOp :
   Vector_Op<"make_index_tuple", [NoSideEffect]>,
     Arguments<(ins Variadic:$operands)>,
diff --git a/third_party/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/third_party/mlir/lib/Dialect/VectorOps/VectorOps.cpp
index 6086531e3c7..7f3be9d9fa9 100644
--- a/third_party/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/third_party/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -995,6 +995,37 @@ static LogicalResult verify(TypeCastOp &op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// CreateMaskOp
+//===----------------------------------------------------------------------===//
+
+ParseResult parseCreateMaskOp(OpAsmParser &parser, OperationState &result) {
+  auto indexType = parser.getBuilder().getIndexType();
+  Type resultType;
+  SmallVector operandInfo;
+  return failure(
+      parser.parseOperandList(operandInfo) ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonType(resultType) ||
+      parser.resolveOperands(operandInfo, indexType, result.operands) ||
+      parser.addTypeToList(resultType, result.types));
+}
+
+static void print(OpAsmPrinter &p, CreateMaskOp &op) {
+  p << op.getOperationName() << ' ';
+  p.printOperands(op.operands());
+  p << " : " << op.getResult()->getType();
+}
+
+static LogicalResult verify(CreateMaskOp &op) {
+  // Verify that an operand was specified for each result vector each dimension.
+  if (op.getNumOperands() !=
+      op.getResult()->getType().cast().getRank())
+    return op.emitOpError(
+        "must specify an operand for each result vector dimension");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // IndexTupleOp
 //===----------------------------------------------------------------------===//

From 97b36b0d4a9389178e510b2f59f98c0c75e0b882 Mon Sep 17 00:00:00 2001
From: Stella Laurenzo 
Date: Tue, 3 Dec 2019 12:04:11 -0800
Subject: [PATCH 231/279] Updated bug link now that it is filed properly
 upstream.

PiperOrigin-RevId: 283594085
Change-Id: If02fbabc9e6f91c196185cf5aa312c21b7f8e924
---
 third_party/llvm/llvm.bzl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/third_party/llvm/llvm.bzl b/third_party/llvm/llvm.bzl
index 5a478827980..0d06b7e8df7 100644
--- a/third_party/llvm/llvm.bzl
+++ b/third_party/llvm/llvm.bzl
@@ -294,7 +294,7 @@ win32_cmake_vars = {
 
     # ThreadPoolExecutor global destructor and thread handshaking do not work
     # on this platform when used as a DLL.
-    # See: https://github.com/google/iree/issues/114
+    # See: https://bugs.llvm.org/show_bug.cgi?id=44211
     "LLVM_ENABLE_THREADS": 0,
 }
 

From 08266e2d237fc78511393ee52bec667a83891c53 Mon Sep 17 00:00:00 2001
From: Berkin Ilbeyi 
Date: Tue, 3 Dec 2019 12:14:26 -0800
Subject: [PATCH 232/279] [XLA] Respect the alternate memory space in layout of
 inputs and outputs.

When the inputs/outputs are pinned to the alternate memory space (e.g. using
go/tpu-fast-mem-inference), ensure the memory space assignment respects the
memory space by ensuring:

  1- Memory space assignment no longer allocates these buffers, assuming
     BufferAssignment will allocate them.
  2- Memory space assignment now assumes these inputs/outputs will live in the
     alternate memory for the entire duration of the computation, hence
     accounting for the reduction in available space in the alternate memory
     space, fixing out-of-memory errors. This also removes redundant alternate
     mem to alternate mem asynchronous copies.

PiperOrigin-RevId: 283596052
Change-Id: Ibcb0a3555960c66bb2ff8f62d3b36a4d35628d9d
---
 .../xla/service/memory_space_assignment.cc    | 100 +++++++++++++++---
 .../xla/service/memory_space_assignment.h     |  13 +++
 .../service/memory_space_assignment_test.cc   |  68 ++++++++++++
 3 files changed, 165 insertions(+), 16 deletions(-)

diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc
index 28c93fb75fd..c1dc635eb81 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc
@@ -244,6 +244,32 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
 
     auto colocated_intervals = GetSortedColocatedIntervals(interval);
 
+    if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) {
+      VLOG(4) << "Interval " << interval.buffer->ToShortString()
+              << " is reserved in the alternate memory. Total reserved bytes = "
+              << reserved_in_bytes_;
+      for (const BufferInterval* colocated_interval : colocated_intervals) {
+        const HloValue* value = colocated_interval->buffer;
+        // Color all of the aliased reserved buffers here because reserved
+        // alternate memory allocations will not have an entry in preset
+        // allocations that is normally used for coloring.
+        for (auto& position : value->positions()) {
+          VLOG(3) << "Coloring " << position.ToString();
+          Shape* shape = ShapeUtil::GetMutableSubshape(
+              position.instruction->mutable_shape(), position.index);
+          CHECK(shape->IsArray()) << "Coloring a shape that is not an array: "
+                                  << position.ToString();
+          shape->mutable_layout()->set_memory_space(
+              options_.alternate_memory_space);
+        }
+      }
+      // Increment the reserved part of alternate memory so that it is not
+      // available for other buffers. Since all colocated intervals should have
+      // the same size, just use the first one.
+      reserved_in_bytes_ += options_.size_fn(*colocated_intervals[0]->buffer);
+      continue;
+    }
+
     if (colocated_intervals.size() > 1 &&
         !options_.allocate_across_sequential_calls) {
       VLOG(4) << "Not allocating " << interval.buffer->ToShortString()
@@ -366,10 +392,8 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
 }
 
 void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() {
-  // Go through the parameters and outputs and pin them to default memory by
-  // adding a required assignment.
-  // TODO(berkin): If these values are already marked alternate memory, use
-  // those instead.
+  // Go through the parameters and outputs and pin them to the corresponding
+  // memory by adding a required assignment.
   const HloModule& module = alias_analysis_.dataflow_analysis().module();
   const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
   HloComputation* entry_computation = module.entry_computation();
@@ -379,16 +403,22 @@ void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() {
         instruction_schedule.at(parameter_instruction);
     ShapeUtil::ForEachSubshape(
         parameter_instruction->shape(),
-        [&](const Shape& /*subshape*/, const ShapeIndex& index) {
+        [&](const Shape& subshape, const ShapeIndex& index) {
+          MemorySpace memory_space = MemorySpace::kDefault;
+          if (subshape.has_layout() && subshape.layout().memory_space() ==
+                                           options_.alternate_memory_space) {
+            memory_space = MemorySpace::kAlternate;
+          }
           for (const HloBuffer* buffer :
                alias_analysis_.ComputeBuffersAt(parameter_instruction, index)) {
             for (const HloValue* value : buffer->values()) {
               VLOG(3) << "Adding required assignment for parameter value = "
                       << value->ToShortString()
-                      << " time = " << parameter_instruction_time;
+                      << " time = " << parameter_instruction_time << " space = "
+                      << (memory_space == MemorySpace::kDefault ? "def"
+                                                                : "alt");
               required_assignments_[value].push_back(
-                  {/*memory_space=*/MemorySpace::kDefault,
-                   /*time=*/parameter_instruction_time});
+                  {memory_space, /*time=*/parameter_instruction_time});
             }
           }
         });
@@ -397,21 +427,56 @@ void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() {
   int64 root_instruction_time = instruction_schedule.at(root_instruction);
   ShapeUtil::ForEachSubshape(
       root_instruction->shape(),
-      [&](const Shape& /*subshape*/, const ShapeIndex& index) {
+      [&](const Shape& subshape, const ShapeIndex& index) {
+        MemorySpace memory_space = MemorySpace::kDefault;
+        if (subshape.has_layout() && subshape.layout().memory_space() ==
+                                         options_.alternate_memory_space) {
+          memory_space = MemorySpace::kAlternate;
+        }
         for (const HloBuffer* buffer :
              alias_analysis_.ComputeBuffersAt(root_instruction, index)) {
           for (const HloValue* value : buffer->values()) {
             VLOG(3) << "Adding required assignment for output value = "
                     << value->ToShortString()
-                    << " time = " << root_instruction_time;
+                    << " time = " << root_instruction_time << " space = "
+                    << (memory_space == MemorySpace::kDefault ? "def" : "alt");
             required_assignments_[value].push_back(
-                {/*memory_space=*/MemorySpace::kDefault,
-                 /*time=*/root_instruction_time});
+                {memory_space, /*time=*/root_instruction_time});
           }
         }
       });
 }
 
+bool AlternateMemoryBestFitHeap::AreIntervalsReservedInAlternateMemory(
+    absl::Span colocated_intervals) const {
+  auto is_position_in_alternate_memory = [&](const HloPosition& position) {
+    const Shape& shape = position.shape();
+    return shape.has_layout() &&
+           shape.layout().memory_space() == options_.alternate_memory_space;
+  };
+
+  const HloModule& module = alias_analysis_.dataflow_analysis().module();
+  const HloComputation* entry_computation = module.entry_computation();
+  const HloInstruction* root_instruction =
+      entry_computation->root_instruction();
+  for (const BufferInterval* colocated_interval : colocated_intervals) {
+    const HloValue* value = colocated_interval->buffer;
+    if (value->defining_instruction()->opcode() == HloOpcode::kParameter &&
+        value->defining_instruction()->parent() == entry_computation &&
+        is_position_in_alternate_memory(value->defining_position())) {
+      return true;
+    }
+
+    for (const HloPosition& position : value->positions()) {
+      if (position.instruction == root_instruction &&
+          is_position_in_alternate_memory(position)) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
 void AlternateMemoryBestFitHeap::CommitPendingChunks() {
   for (auto interval_and_chunk : pending_chunks_) {
     VLOG(3) << "Committing chunk: " << interval_and_chunk.first.start << "-"
@@ -482,8 +547,11 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
   if (required_assignment_it != required_assignments_.end()) {
     for (const RequiredMemoryAssignment& required_assignment :
          required_assignment_it->second) {
-      VLOG(3) << "Required assignment at time = " << required_assignment.time;
-      // TODO(berkin): Handle memory requirements for alternate memory space.
+      VLOG(3) << "Required assignment at time = " << required_assignment.time
+              << " space = "
+              << (required_assignment.memory_space == MemorySpace::kDefault
+                      ? "def"
+                      : "alt");
       if (required_assignment.memory_space == MemorySpace::kDefault) {
         if (required_assignment.time == start_time) {
           definition_requires_buffer_in_default_mem = true;
@@ -613,7 +681,7 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
     }
     ChunkCandidate chunk_candidate = FindChunkCandidate(alternate_mem_interval);
     // Check if the new heap size fits within limits.
-    if (chunk_candidate.heap_size < options_.max_size_in_bytes) {
+    if (chunk_candidate.heap_size < available_heap_size()) {
       VLOG(3) << "Move the buffer to alternate memory at "
               << alternate_mem_interval.start
               << ". Offset = " << chunk_candidate.chunk.offset
@@ -748,7 +816,7 @@ bool AlternateMemoryBestFitHeap::TryAllocatingInAlternateMemoryNoCopy(
   alternate_mem_interval.end = end_time;
   // Check if the new heap size fits within limits. Also ensure if a
   // preferred offset was provided, that offset was used.
-  if (chunk_candidate.heap_size <= options_.max_size_in_bytes &&
+  if (chunk_candidate.heap_size <= available_heap_size() &&
       (preferred_offset == -1 ||
        preferred_offset == chunk_candidate.chunk.offset)) {
     VLOG(3) << "Keep the buffer in alternate memory. Offset = "
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h
index a8b3310cf24..20551feb715 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.h
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.h
@@ -547,6 +547,12 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
   // Adds input and outputs as required assignments.
   void AddInputAndOutputRequiredAssignments();
 
+  // Returns true if the colocated intervals in the argument are in a parameter
+  // or root instruction of the entry computation and are reserved by the user
+  // to be in the alternate memory space.
+  bool AreIntervalsReservedInAlternateMemory(
+      absl::Span colocated_intervals) const;
+
   // Given a buffer interval, returns the colocated intervals. Unlike the
   // similar GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations, it
   // returns the colocated intervals sorted by scheduled time.
@@ -575,6 +581,11 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
                           const ChunkCandidate& chunk_candidate);
   void CommitPendingChunks();
 
+  // Returns the available heap size in the alternate memory.
+  int64 available_heap_size() const {
+    return options_.max_size_in_bytes - reserved_in_bytes_;
+  }
+
   MemorySpaceAssignment::AllocationMap* allocation_map_;
   const MemorySpaceAssignment::Options& options_;
   const HloAliasAnalysis& alias_analysis_;
@@ -588,6 +599,8 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
   // and outputs).
   absl::flat_hash_map>
       required_assignments_;
+  // Number of bytes reserved in alternate memory space.
+  int64 reserved_in_bytes_ = 0;
 };
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
index 7e3ce7dfbbd..068834e5701 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
@@ -2125,6 +2125,74 @@ TEST_P(MemorySpaceAssignmentTest, EvictionsShouldntBeDelayed) {
   }
 }
 
+TEST_P(MemorySpaceAssignmentTest,
+       InputOutputsInAlternateMemShouldntBeAssigned) {
+  // When input/outputs are marked to be in the alternate memory (e.g.
+  // go/tpu-fast-mem-inference), do not allocate those and assume they will live
+  // in the alternate memory for the entire computation. The BufferAssignment
+  // pass, which is run after this, will allocate those buffers.
+  HloComputation::Builder builder(TestName());
+  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
+  Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
+      F32, {2, 3},
+      /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
+      kAlternateMemorySpace);
+  // p0 is in the default memory space.
+  HloInstruction* p0 =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
+  // p1 is in the alternate memory space.
+  HloInstruction* p1 = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, shape_in_alternate_mem, "p1"));
+  HloInstruction* negate0 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
+  HloInstruction* negate1 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
+  HloInstruction* negate2 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
+  HloInstruction* negate3 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
+  HloInstruction* negate4 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
+  HloInstruction* negate5 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
+  HloInstruction* negate6 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
+  HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
+      shape_in_alternate_mem, HloOpcode::kAdd, negate6, p1));
+  // Index {0} of the root instruction is in the alternate memory space, index
+  // {1} is in the default memory space.
+  HloInstruction* tuple =
+      builder.AddInstruction(HloInstruction::CreateTuple({add, negate5}));
+
+  auto module = CreateNewVerifiedModule();
+  HloComputation* computation = module->AddEntryComputation(builder.Build());
+
+  HloSchedule schedule(module.get());
+  schedule.set_sequence(computation,
+                        {p0, p1, negate0, negate1, negate2, negate3, negate4,
+                         negate5, negate6, add, tuple});
+  TF_CHECK_OK(module->set_schedule(schedule));
+
+  std::unique_ptr preset_assignments =
+      AssignMemorySpace(module.get());
+
+  // Ensure that p1 is in the alternate memory and add, which has p1 as an
+  // operand, has a direct dependency to p1 (no CopyStart/CopyDone).
+  EXPECT_THAT(p1, op::ShapeWithLayout(shape_in_alternate_mem));
+  EXPECT_THAT(add, op::Add(op::Negate(), op::Parameter(1)));
+  // Make sure add is still in the alternate memory space.
+  EXPECT_THAT(add, op::ShapeWithLayout(shape_in_alternate_mem));
+
+  // Check the preset assignments and ensure the inputs/outputs in the alternate
+  // memory space aren't in the preset assignments. Inputs/outputs in the
+  // alternate memory space are left to BufferAssignment to be allocated.
+  for (const auto& position_and_chunk : preset_assignments->chunks()) {
+    const HloPosition& position = position_and_chunk.first;
+    EXPECT_NE(position.instruction, p1);
+    EXPECT_NE(position.instruction, add);
+  }
+}
+
 INSTANTIATE_TEST_SUITE_P(MemorySpaceAssignmentInstantiation,
                          MemorySpaceAssignmentTest,
                          ::testing::Values(false, true));

From 0c3af9326f23cb4a254c41a80679ab39f87e82d0 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Tue, 3 Dec 2019 13:03:44 -0800
Subject: [PATCH 233/279] Fix possible segfault with ShapeHandle.

PiperOrigin-RevId: 283605264
Change-Id: I4531583f987439338f3207697f388fd1c74ccf23
---
 tensorflow/core/framework/shape_inference.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index aa0a6247312..b11df7e5d8a 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -306,7 +306,7 @@ class InferenceContext {
   // idx can be negative for an offset from end of dimensions.
   // idx must be in the range [-1 * s.rank, s.rank).
   DimensionHandle Dim(ShapeHandle s, int64 idx) {
-    if (s->rank_ == kUnknownRank) {
+    if (!s.Handle() || s->rank_ == kUnknownRank) {
       return UnknownDim();
     }
     return DimKnownRank(s, idx);

From 1e7a91e26abd93086a376ef6212bdf463c747dca Mon Sep 17 00:00:00 2001
From: Sean Silva 
Date: Tue, 3 Dec 2019 13:23:40 -0800
Subject: [PATCH 234/279] Don't canonicalize away casts between different
 types.

The previous code would incorrectly only check the element type, rather than exact type equality. Failure to do so can trigger many different kinds of verifier errors.

PiperOrigin-RevId: 283609199
Change-Id: I3bbd8b41a6a2c8edd2e9d97b32eda78e546975ac
---
 .../compiler/mlir/tensorflow/tests/canonicalize.mlir  | 11 +++++++++++
 .../mlir/tensorflow/transforms/canonicalize.td        |  6 ++++--
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
index 2db64262094..a2cc33a8201 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
@@ -101,6 +101,17 @@ func @testDifferentCastType(%arg0: tensor<8x16x32x64xf32>) -> (tensor<8x16x32x64
 // CHECK: return %0, %1
 }
 
+// CHECK-LABEL: testCompatibleCastType
+func @testCompatibleCastType(%arg0: tensor) -> (tensor<10xf32>, tensor<10xf32>) {
+  %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor) -> tensor<10xf32>
+  %1 = "tf.Cast"(%arg0) {Truncate = true} : (tensor) -> tensor<10xf32>
+  return %0, %1: tensor<10xf32>, tensor<10xf32>
+
+// CHECK: %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor) -> tensor<10xf32>
+// CHECK: %1 = "tf.Cast"(%arg0) {Truncate = true} : (tensor) -> tensor<10xf32>
+// CHECK: return %0, %1
+}
+
 // CHECK-LABEL: testSameCastTypeAcrossBasicBlocks
 func @testSameCastTypeAcrossBasicBlocks(tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xf32> {
 ^bb0(%arg0: tensor<8x16x32x64xf32>):
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td
index beb7583fc57..7c38b78f239 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td
@@ -22,6 +22,9 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
 def SingleResultAndOperandHaveSameElementType : Constraint<
   CPred<"getElementTypeOrSelf($0) == getElementTypeOrSelf($1)">>;
 
+def SingleResultAndOperandHaveSameType : Constraint<
+  CPred<"$0->getType() == $1->getType()">>;
+
 def IsRank2Tensor : Type, "Rank 2 tensor">;
 
 //===----------------------------------------------------------------------===//
@@ -75,8 +78,7 @@ def BitcastNested : Pat<(TF_BitcastOp (TF_BitcastOp $arg)),
 
 def CastSameType : Pat<(TF_CastOp:$res $arg, $truncate),
                        (replaceWithValue $arg),
-                       [(SingleResultAndOperandHaveSameElementType $res,
-                                                                   $arg)]>;
+                       [(SingleResultAndOperandHaveSameType $res, $arg)]>;
 
 //===----------------------------------------------------------------------===//
 // Conj op patterns.

From 3c28370a9c66f12f04d6a595d5e46eaf5c460d1f Mon Sep 17 00:00:00 2001
From: HyoukJoong Lee 
Date: Tue, 3 Dec 2019 13:28:24 -0800
Subject: [PATCH 235/279] Combine cross-replica / cross-partition AllReduce
 after SPMD partition

PiperOrigin-RevId: 283610192
Change-Id: I801097d159c39d8137457c55906d455e0ee7733d
---
 tensorflow/compiler/xla/service/BUILD         |   1 +
 .../xla/service/all_reduce_simplifier.cc      |   4 +-
 .../compiler/xla/service/ar_crs_combiner.cc   |  52 +-
 .../compiler/xla/service/ar_crs_combiner.h    |  37 +-
 .../xla/service/ar_crs_combiner_test.cc       | 455 +++++++++++++++++-
 .../xla/service/hlo_replication_analysis.cc   | 103 ++--
 .../xla/service/hlo_replication_analysis.h    |  34 +-
 .../service/hlo_replication_analysis_test.cc  | 131 ++++-
 8 files changed, 736 insertions(+), 81 deletions(-)

diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 23d203850fc..a6300d2dc73 100755
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -4197,6 +4197,7 @@ cc_library(
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla/service:hlo_replication_analysis",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/strings",
     ],
diff --git a/tensorflow/compiler/xla/service/all_reduce_simplifier.cc b/tensorflow/compiler/xla/service/all_reduce_simplifier.cc
index b3097b8ff77..541006f04d5 100644
--- a/tensorflow/compiler/xla/service/all_reduce_simplifier.cc
+++ b/tensorflow/compiler/xla/service/all_reduce_simplifier.cc
@@ -28,7 +28,9 @@ limitations under the License.
 namespace xla {
 
 StatusOr AllReduceSimplifier::Run(HloModule* module) {
-  TF_ASSIGN_OR_RETURN(auto replication, HloReplicationAnalysis::Run(module));
+  TF_ASSIGN_OR_RETURN(
+      auto replication,
+      HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/false));
   std::vector all_reduces_to_replace;
   for (auto computation : module->computations()) {
     for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {
diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.cc b/tensorflow/compiler/xla/service/ar_crs_combiner.cc
index ae39906ef52..06aaad351e6 100644
--- a/tensorflow/compiler/xla/service/ar_crs_combiner.cc
+++ b/tensorflow/compiler/xla/service/ar_crs_combiner.cc
@@ -25,6 +25,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_replication_analysis.h"
 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -240,7 +241,8 @@ bool ArCrsCombiner::TupleElementsComputeSameValue(
 /* static */
 bool ArCrsCombiner::TestInstructionsComputeSameValue(HloInstruction* i1,
                                                      HloInstruction* i2) {
-  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1);
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1,
+                         /*spmd_partition=*/false);
   auto module = i1->parent()->parent();
   CHECK_EQ(module, i2->parent()->parent());
   combiner.call_graph_ = CallGraph::Build(module);
@@ -363,14 +365,14 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) {
   }
 }
 
-void ArCrsCombiner::KeepProvablyEqualInstructionGroups() {
+Status ArCrsCombiner::KeepProvablyEqualInstructionGroupsMPMD() {
   for (auto it : all_reduce_map_) {
     auto channel_id = it.first;
     VLOG(2)
         << "KeepProvablyEqualInstructionGroups. Checking AllReduce channel id: "
         << channel_id << "\n";
     auto pairs_vec = it.second;
-    CHECK_EQ(pairs_vec.size(), num_spatial_partitions_);
+    TF_RET_CHECK(pairs_vec.size() == num_spatial_partitions_);
     auto instr_0 = pairs_vec[0].ar;
     for (int i = 1; i < pairs_vec.size(); ++i) {
       auto instr_i = pairs_vec[i].ar;
@@ -393,6 +395,44 @@ void ArCrsCombiner::KeepProvablyEqualInstructionGroups() {
       }
     }
   }
+  return Status::OK();
+}
+
+Status ArCrsCombiner::KeepProvablyEqualInstructionGroupsSPMD(
+    HloModule* module) {
+  // For SPMD mode, use HloReplicationAnalysis to figure out HLO value
+  // equivalence across partitions.
+  TF_ASSIGN_OR_RETURN(
+      auto replication_analysis,
+      HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/true));
+
+  for (auto it : all_reduce_map_) {
+    auto channel_id = it.first;
+    VLOG(2)
+        << "KeepProvablyEqualInstructionGroups. Checking AllReduce channel id: "
+        << channel_id << "\n";
+    auto pairs_vec = it.second;
+    TF_RET_CHECK(pairs_vec.size() == 1);
+    auto instr = pairs_vec[0].ar;
+    auto next = instr->users()[0];
+    while (true) {
+      // The patterns we detect in ArCrsCombiner::MatchesArCrsPattern()
+      // guarantee that the HLO produces an array.
+      TF_RET_CHECK(next->shape().IsArray());
+      if (!replication_analysis->HloInstructionIsReplicatedAt(next, {})) {
+        all_reduce_map_.erase(channel_id);
+        VLOG(2) << "KeepProvablyEqualInstructionGroups. Erased AllReduce "
+                   "channel id: "
+                << channel_id << "\n";
+        break;
+      }
+      if (next->IsCrossReplicaAllReduce()) {
+        break;
+      }
+      next = next->users()[0];
+    }
+  }
+  return Status::OK();
 }
 
 StatusOr ArCrsCombiner::RewriteGraph() {
@@ -460,7 +500,11 @@ StatusOr ArCrsCombiner::Run(HloModule* module) {
 
   GroupAllReducesById(module);
 
-  KeepProvablyEqualInstructionGroups();
+  if (spmd_partition_) {
+    TF_RETURN_IF_ERROR(KeepProvablyEqualInstructionGroupsSPMD(module));
+  } else {
+    TF_RETURN_IF_ERROR(KeepProvablyEqualInstructionGroupsMPMD());
+  }
 
   return RewriteGraph();
 }
diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.h b/tensorflow/compiler/xla/service/ar_crs_combiner.h
index a85e18d328c..95443c0c74a 100644
--- a/tensorflow/compiler/xla/service/ar_crs_combiner.h
+++ b/tensorflow/compiler/xla/service/ar_crs_combiner.h
@@ -25,18 +25,21 @@ limitations under the License.
 
 namespace xla {
 
-// When the HLO graph contains a cross-module AllReduce, followed by some simple
-// linear operations, followed by a cross-replica AllReduce (also known as
-// cross-replica sum, or CRS), we can combine the CMAR and the CRAR, to use an
-// efficient AllReduce implementation that fully utilizes the interconnect
-// bandwidth.
-// Such sequences appear in spatially partitioned models.
+// When the HLO graph contains a cross-module AllReduce (N separate AllReduce
+// ops that share the same channel_id for MPMD partitioning, or 1 AllReduce op
+// for SPMD partitioning), followed by some simple linear operations, followed
+// by a cross-replica AllReduce (also known as cross-replica sum, or CRS), we
+// can combine the CMAR and the CRAR, to use an efficient AllReduce
+// implementation that fully utilizes the interconnect bandwidth.
+//
+// Such sequences appear in spatially partitioned models (either MPMD or SPMD).
 // This pass must run right after spatial partitioning, when the code is still
 // in a single HLO module.
 //
 // The steps are:
 // 1) Find CMARs followed by simple ops followed by CRARs.
-// 2) Group CMARs by channel_id. They must all be rewritten.
+// 2) Group CMARs by channel_id. They must all be rewritten. For SPMD
+//    partitioning, there will only be a single CMAR for each channel_id.
 // 3) Prove that the CMAR patterns in each core produce the same result.
 // 4) Eliminate the CMAR, and if it feeds an addition/subtraction, divide the
 //    other operand by the number of spatial partitions.
@@ -69,9 +72,11 @@ namespace xla {
 //
 class ArCrsCombiner : public HloModulePass {
  public:
-  ArCrsCombiner(int num_spatial_partitions, int num_replicas)
+  ArCrsCombiner(int num_spatial_partitions, int num_replicas,
+                bool spmd_partition)
       : num_spatial_partitions_(num_spatial_partitions),
-        num_replicas_(num_replicas) {}
+        num_replicas_(num_replicas),
+        spmd_partition_(spmd_partition) {}
   absl::string_view name() const override { return "ar-crs-combiner"; }
   StatusOr Run(HloModule* module) override;
 
@@ -153,7 +158,10 @@ class ArCrsCombiner : public HloModulePass {
 
   // Looks at each AllReduce group in all_reduce_map_, and keeps only the
   // groups for which it's safe to move the AllReduce later in the HLO graph.
-  void KeepProvablyEqualInstructionGroups();
+  Status KeepProvablyEqualInstructionGroupsMPMD();
+
+  // Same as above, but runs on SPMD partitioned module instead of MPMD.
+  Status KeepProvablyEqualInstructionGroupsSPMD(HloModule* module);
 
   // Performs the graph rewrite that eliminates the early AllReduce and turns
   // the later CRS into an AllReduce.
@@ -163,6 +171,15 @@ class ArCrsCombiner : public HloModulePass {
 
   int num_replicas_;
 
+  // Run this combiner pass assuming the input module is an SPMD partitioned
+  // module (as opposed to MPMD partitioned).
+  //
+  // The main difference between the two w.r.t. this pass is that there would be
+  // N all-reduce ops for each channel in MPMD mode, whereas there is only 1
+  // for each channel in SPMD mode. Also we use HloReplicationAnalysis for HLO
+  // equivalence check in SPMD mode.
+  bool spmd_partition_;
+
   // Map from all-reduce ids to the AR/CRS pairs.
   absl::flat_hash_map> all_reduce_map_;
 
diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc
index accc0684e8e..609da2c33a0 100644
--- a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc
+++ b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc
@@ -452,7 +452,8 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
   auto crs_before =
       module->entry_computation()->root_instruction()->operands()[0];
   auto replica_groups_before = crs_before->replica_groups();
-  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/false);
   auto changed = combiner.Run(module.get()).ValueOrDie();
   EXPECT_TRUE(changed);
   EXPECT_THAT(module->entry_computation()->root_instruction(),
@@ -464,6 +465,55 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
   CompareReplicaGroups(replica_groups_before, replica_groups_after);
 }
 
+TEST_F(ArCrsCombinerTest, RewriteArConvertCrsSPMD) {
+  const char* module_str = R"(
+HloModule foobar
+
+%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
+  %a = bf16[] parameter(0)
+  %b = bf16[] parameter(1)
+  ROOT %add = bf16[] add(%a, %b)
+}
+
+%sum.f32 (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(%x, %y)
+}
+
+ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
+  %p = bf16[] parameter(0)
+  %all-reduce.ar.1 = bf16[]
+      all-reduce(%p),
+      replica_groups={{0},{1}},
+      channel_id=1,
+      to_apply=%sum.bf16
+  %convert.1 = f32[] convert(%all-reduce.ar.1)
+  %all-reduce.1 = f32[]
+      all-reduce(%convert.1),
+      replica_groups={{0,1}},
+      to_apply=%sum.f32
+  ROOT %tuple = (f32[]) tuple(%all-reduce.1)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module,
+                          ParseAndReturnVerifiedModule(module_str));
+  auto crs_before =
+      module->entry_computation()->root_instruction()->operands()[0];
+  auto replica_groups_before = crs_before->replica_groups();
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         true);
+  auto changed = combiner.Run(module.get()).ValueOrDie();
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              op::Tuple(op::AllReduce(op::Convert(op::Parameter()))));
+  auto crs_after =
+      module->entry_computation()->root_instruction()->operands()[0];
+  auto replica_groups_after = crs_after->replica_groups();
+  CompareReplicaGroups(replica_groups_before, replica_groups_after);
+}
+
 TEST_F(ArCrsCombinerTest, RewriteArBitcastCrs) {
   const char* module_str = R"(
 HloModule foobar
@@ -520,7 +570,8 @@ ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) {
   auto crs_before =
       module->entry_computation()->root_instruction()->operands()[0];
   auto replica_groups_before = crs_before->replica_groups();
-  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/false);
   auto changed = combiner.Run(module.get()).ValueOrDie();
   EXPECT_TRUE(changed);
   EXPECT_THAT(module->entry_computation()->root_instruction(),
@@ -587,7 +638,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   auto crs_before =
       module->entry_computation()->root_instruction()->operands()[0];
   auto replica_groups_before = crs_before->replica_groups();
-  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/false);
   auto changed = combiner.Run(module.get()).ValueOrDie();
   EXPECT_TRUE(changed);
   EXPECT_THAT(
@@ -600,6 +652,47 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   CompareReplicaGroups(replica_groups_before, replica_groups_after);
 }
 
+TEST_F(ArCrsCombinerTest, RewriteArMultiplyCrsSPMD) {
+  const char* module_str = R"(
+HloModule foobar
+
+%sum.f32 (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(%x, %y)
+}
+
+ENTRY %entrycomp (p: f32[]) -> (f32[]) {
+  %p = f32[] parameter(0)
+  %constant.f32 = f32[] constant(123)
+
+  %all-reduce.ar.1 = f32[] all-reduce(%p), replica_groups={{0},{1}},
+      channel_id=1, to_apply=%sum.f32
+  %multiply.1 = f32[] multiply(%all-reduce.ar.1, %constant.f32)
+  %all-reduce.1 = f32[] all-reduce(%multiply.1), replica_groups={{0,1}},
+      to_apply=%sum.f32, sharding={maximal device=0}
+  ROOT %tuple = (f32[]) tuple(%all-reduce.1)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module,
+                          ParseAndReturnVerifiedModule(module_str));
+  auto crs_before =
+      module->entry_computation()->root_instruction()->operands()[0];
+  auto replica_groups_before = crs_before->replica_groups();
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/true);
+  auto changed = combiner.Run(module.get()).ValueOrDie();
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      op::Tuple(op::AllReduce(op::Multiply(op::Parameter(), op::Constant()))));
+  auto crs_after =
+      module->entry_computation()->root_instruction()->operands()[0];
+  auto replica_groups_after = crs_after->replica_groups();
+  CompareReplicaGroups(replica_groups_before, replica_groups_after);
+}
+
 TEST_F(ArCrsCombinerTest, RewriteArConvertAddCrs) {
   const char* module_str = R"(
 HloModule foobar
@@ -668,7 +761,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   auto crs_before =
       module->entry_computation()->root_instruction()->operands()[0];
   auto replica_groups_before = crs_before->replica_groups();
-  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/false);
   auto changed = combiner.Run(module.get()).ValueOrDie();
   EXPECT_TRUE(changed);
   EXPECT_THAT(
@@ -684,6 +778,55 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   CompareReplicaGroups(replica_groups_before, replica_groups_after);
 }
 
+TEST_F(ArCrsCombinerTest, RewriteArConvertAddCrsSPMD) {
+  const char* module_str = R"(
+HloModule foobar
+
+%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
+  %a = bf16[] parameter(0)
+  %b = bf16[] parameter(1)
+  ROOT %add = bf16[] add(%a, %b)
+}
+
+%sum.f32 (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(%x, %y)
+}
+
+ENTRY %entrycomp (p: f32[]) -> (f32[]) {
+  %p = f32[] parameter(0)
+  %constant.bf16 = bf16[] constant(1)
+  %constant.f32 = f32[] constant(2)
+
+  %all-reduce.ar.1 = bf16[] all-reduce(%constant.bf16), replica_groups={{0},{1}},
+      channel_id=1, to_apply=%sum.bf16
+  %convert.1 = f32[] convert(%all-reduce.ar.1), sharding={maximal device=0}
+  %add.1 = f32[] add(%constant.f32, %convert.1)
+  %all-reduce.1 = f32[] all-reduce(%add.1), replica_groups={{0,1}},
+      to_apply=%sum.f32
+  ROOT %tuple = (f32[]) tuple(%all-reduce.1)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module,
+                          ParseAndReturnVerifiedModule(module_str));
+  auto crs_before =
+      module->entry_computation()->root_instruction()->operands()[0];
+  auto replica_groups_before = crs_before->replica_groups();
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/true);
+  auto changed = combiner.Run(module.get()).ValueOrDie();
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              op::Tuple(op::AllReduce(op::Add(
+                  op::Divide(op::Constant(), op::Constant()), op::Convert()))));
+  auto crs_after =
+      module->entry_computation()->root_instruction()->operands()[0];
+  auto replica_groups_after = crs_after->replica_groups();
+  CompareReplicaGroups(replica_groups_before, replica_groups_after);
+}
+
 TEST_F(ArCrsCombinerTest, OtherSummandNotTheSameDontRewrite) {
   const char* module_str = R"(
 HloModule foobar
@@ -750,7 +893,46 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
 
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module,
                           ParseAndReturnVerifiedModule(module_str));
-  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/false);
+  auto changed = combiner.Run(module.get()).ValueOrDie();
+  EXPECT_FALSE(changed);
+}
+
+TEST_F(ArCrsCombinerTest, OtherSummandNotTheSameDontRewriteSPMD) {
+  const char* module_str = R"(
+HloModule foobar
+
+%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
+  %a = bf16[] parameter(0)
+  %b = bf16[] parameter(1)
+  ROOT %add = bf16[] add(%a, %b)
+}
+
+%sum.f32 (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(%x, %y)
+}
+
+ENTRY %entrycomp (p: f32[]) -> (f32[]) {
+  %p = f32[] parameter(0)
+  %constant.bf16 = bf16[] constant(1)
+  %constant.f32.1 = f32[] constant(2)
+
+  %all-reduce.ar.1 = bf16[] all-reduce(%constant.bf16), replica_groups={{0},{1}},
+      channel_id=1, to_apply=%sum.bf16
+  %convert.1 = f32[] convert(%all-reduce.ar.1)
+  %add.1 = f32[] add(%p, %convert.1)
+  %all-reduce.1 = f32[] all-reduce(%add.1), replica_groups={{0,1}}, to_apply=%sum.f32
+  ROOT %tuple = (f32[]) tuple(%all-reduce.1)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module,
+                          ParseAndReturnVerifiedModule(module_str));
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/true);
   auto changed = combiner.Run(module.get()).ValueOrDie();
   EXPECT_FALSE(changed);
 }
@@ -810,7 +992,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   auto crs_before =
       module->entry_computation()->root_instruction()->operands()[0];
   auto replica_groups_before = crs_before->replica_groups();
-  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/false);
   auto changed = combiner.Run(module.get()).ValueOrDie();
   EXPECT_TRUE(changed);
   EXPECT_THAT(module->entry_computation()->root_instruction(),
@@ -884,7 +1067,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   auto crs_before =
       module->entry_computation()->root_instruction()->operands()[0];
   auto replica_groups_before = crs_before->replica_groups();
-  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/false);
   auto changed = combiner.Run(module.get()).ValueOrDie();
   EXPECT_TRUE(changed);
   EXPECT_THAT(module->entry_computation()->root_instruction(),
@@ -902,6 +1086,50 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   CompareReplicaGroups(replica_groups_before, replica_groups_after);
 }
 
+TEST_F(ArCrsCombinerTest, RewriteMultipleAddsSPMD) {
+  const char* module_str = R"(
+HloModule foobar
+
+%sum (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(%x, %y)
+}
+
+ENTRY %entrycomp (p: f32[]) -> (f32[]) {
+  %p = f32[] parameter(0)
+  %constant.1 = f32[] constant(1)
+  %constant.2 = f32[] constant(2)
+
+  %all-reduce.ar.1 = f32[] all-reduce(%p), replica_groups={{0},{1}},
+      channel_id=1, to_apply=%sum
+  %add.11 = f32[] add(%constant.1, %all-reduce.ar.1)
+  %add.12 = f32[] add(%constant.2, %add.11)
+  %all-reduce.1 = f32[] all-reduce(%add.12), replica_groups={{0,1}}, to_apply=%sum
+  ROOT %tuple = (f32[]) tuple(%all-reduce.1)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module,
+                          ParseAndReturnVerifiedModule(module_str));
+  auto crs_before =
+      module->entry_computation()->root_instruction()->operands()[0];
+  auto replica_groups_before = crs_before->replica_groups();
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/true);
+  auto changed = combiner.Run(module.get()).ValueOrDie();
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              op::Tuple(op::AllReduce(
+                  op::Add(op::Divide(op::Constant(), op::Constant()),
+                          op::Add(op::Divide(op::Constant(), op::Constant()),
+                                  op::Parameter())))));
+  auto crs_after =
+      module->entry_computation()->root_instruction()->operands()[0];
+  auto replica_groups_after = crs_after->replica_groups();
+  CompareReplicaGroups(replica_groups_before, replica_groups_after);
+}
+
 TEST_F(ArCrsCombinerTest, RewriteArSubtractCrs) {
   const char* module_str = R"(
 HloModule foobar
@@ -957,7 +1185,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   auto crs_before =
       module->entry_computation()->root_instruction()->operands()[0];
   auto replica_groups_before = crs_before->replica_groups();
-  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/false);
   auto changed = combiner.Run(module.get()).ValueOrDie();
   EXPECT_TRUE(changed);
   EXPECT_THAT(
@@ -973,6 +1202,47 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   CompareReplicaGroups(replica_groups_before, replica_groups_after);
 }
 
+TEST_F(ArCrsCombinerTest, RewriteArSubtractCrsSPMD) {
+  const char* module_str = R"(
+HloModule foobar
+
+%sum.f32 (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(%x, %y)
+}
+
+ENTRY %entrycomp (p: f32[]) -> (f32[]) {
+  %p = f32[] parameter(0)
+  %constant.f32 = f32[] constant(123)
+  %all-reduce.ar.1 = f32[] all-reduce(%p), replica_groups={{0},{1}},
+      channel_id=1, to_apply=%sum.f32
+  %sub.1 = f32[] subtract(%constant.f32, %all-reduce.ar.1)
+  %all-reduce.1 = f32[] all-reduce(%sub.1), replica_groups={{0,1}},
+      to_apply=%sum.f32
+  ROOT %tuple = (f32[]) tuple(%all-reduce.1)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module,
+                          ParseAndReturnVerifiedModule(module_str));
+  auto crs_before =
+      module->entry_computation()->root_instruction()->operands()[0];
+  auto replica_groups_before = crs_before->replica_groups();
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/true);
+  auto changed = combiner.Run(module.get()).ValueOrDie();
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      op::Tuple(op::AllReduce(op::Subtract(
+          op::Divide(op::Constant(), op::Constant()), op::Parameter()))));
+  auto crs_after =
+      module->entry_computation()->root_instruction()->operands()[0];
+  auto replica_groups_after = crs_after->replica_groups();
+  CompareReplicaGroups(replica_groups_before, replica_groups_after);
+}
+
 TEST_F(ArCrsCombinerTest, RewriteMultipleARsLeft) {
   const char* module_str = R"(
 HloModule foobar
@@ -1047,7 +1317,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   auto crs_before =
       module->entry_computation()->root_instruction()->operands()[0];
   auto replica_groups_before = crs_before->replica_groups();
-  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/false);
   auto changed = combiner.Run(module.get()).ValueOrDie();
   EXPECT_TRUE(changed);
   EXPECT_THAT(module->entry_computation()->root_instruction(),
@@ -1065,6 +1336,53 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   CompareReplicaGroups(replica_groups_before, replica_groups_after);
 }
 
+TEST_F(ArCrsCombinerTest, RewriteMultipleARsLeftSPMD) {
+  const char* module_str = R"(
+HloModule foobar
+
+%sum (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(%x, %y)
+}
+
+ENTRY %entrycomp (p: f32[]) -> (f32[]) {
+  %p = f32[] parameter(0)
+  %const1 = f32[] constant(1)
+  %const2 = f32[] constant(2)
+
+  %ar11 = f32[] all-reduce(%p), replica_groups={{0},{1}}, channel_id=1,
+      to_apply=%sum
+  %add11 = f32[] add(%ar11, %const1)
+  %ar12 = f32[] all-reduce(%p), replica_groups={{0},{1}}, channel_id=2,
+      to_apply=%sum
+  %add12 = f32[] add(%add11, %ar12)
+  %crs1 = f32[] all-reduce(%add12), replica_groups={{0,1}},
+      to_apply=%sum
+  ROOT %tuple = (f32[]) tuple(%crs1)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module,
+                          ParseAndReturnVerifiedModule(module_str));
+  auto crs_before =
+      module->entry_computation()->root_instruction()->operands()[0];
+  auto replica_groups_before = crs_before->replica_groups();
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/true);
+  auto changed = combiner.Run(module.get()).ValueOrDie();
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      op::Tuple(op::AllReduce(op::Add(
+          op::Add(op::Parameter(), op::Divide(op::Constant(), op::Constant())),
+          op::Parameter()))));
+  auto crs_after =
+      module->entry_computation()->root_instruction()->operands()[0];
+  auto replica_groups_after = crs_after->replica_groups();
+  CompareReplicaGroups(replica_groups_before, replica_groups_after);
+}
+
 TEST_F(ArCrsCombinerTest, RewriteMultipleARsRight) {
   const char* module_str = R"(
 HloModule foobar
@@ -1139,7 +1457,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   auto crs_before =
       module->entry_computation()->root_instruction()->operands()[0];
   auto replica_groups_before = crs_before->replica_groups();
-  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/false);
   auto changed = combiner.Run(module.get()).ValueOrDie();
   EXPECT_TRUE(changed);
   EXPECT_THAT(
@@ -1159,6 +1478,51 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   CompareReplicaGroups(replica_groups_before, replica_groups_after);
 }
 
+TEST_F(ArCrsCombinerTest, RewriteMultipleARsRightSPMD) {
+  const char* module_str = R"(
+HloModule foobar
+
+%sum (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(%x, %y)
+}
+
+ENTRY %entrycomp (p: f32[]) -> (f32[]) {
+  %p = f32[] parameter(0)
+  %const1 = f32[] constant(1)
+  %const2 = f32[] constant(2)
+
+  %ar11 = f32[] all-reduce(%p), replica_groups={{0},{1}}, channel_id=1, to_apply=%sum
+  %ar12 = f32[] all-reduce(%p), replica_groups={{0},{1}}, channel_id=2, to_apply=%sum
+  %add11 = f32[] add(%ar12, %const1)
+  %add12 = f32[] add(%ar11, %add11)
+  %crs1 = f32[] all-reduce(%add12), replica_groups={{0,1}}, to_apply=%sum
+  ROOT %tuple = (f32[]) tuple(%crs1)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module,
+                          ParseAndReturnVerifiedModule(module_str));
+  auto crs_before =
+      module->entry_computation()->root_instruction()->operands()[0];
+  auto replica_groups_before = crs_before->replica_groups();
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/true);
+  auto changed = combiner.Run(module.get()).ValueOrDie();
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              op::Tuple(op::AllReduce(op::Add(
+                  op::Parameter(),
+                  op::Add(op::Parameter(),
+                          op::Divide(op::Constant(), op::Constant()))))));
+
+  auto crs_after =
+      module->entry_computation()->root_instruction()->operands()[0];
+  auto replica_groups_after = crs_after->replica_groups();
+  CompareReplicaGroups(replica_groups_before, replica_groups_after);
+}
+
 TEST_F(ArCrsCombinerTest, OneReplicaDontRewrite) {
   const char* module_str = R"(
 HloModule foobar
@@ -1217,7 +1581,45 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
 
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module,
                           ParseAndReturnVerifiedModule(module_str));
-  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1,
+                         /*spmd_partition=*/false);
+  auto changed = combiner.Run(module.get()).ValueOrDie();
+  EXPECT_FALSE(changed);
+}
+
+TEST_F(ArCrsCombinerTest, OneReplicaDontRewriteSPMD) {
+  const char* module_str = R"(
+HloModule foobar
+
+%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
+  %a = bf16[] parameter(0)
+  %b = bf16[] parameter(1)
+  ROOT %add = bf16[] add(%a, %b)
+}
+
+%sum.f32 (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(%x, %y)
+}
+
+ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
+  %p = bf16[] parameter(0)
+  %constant.bf16 = bf16[] constant(1)
+
+  %all-reduce.ar.1 = bf16[] all-reduce(%p), replica_groups={{0}},
+      channel_id=1, to_apply=%sum.bf16
+  %convert.1 = f32[] convert(%all-reduce.ar.1)
+  %all-reduce.1 = f32[] all-reduce(%convert.1),
+      replica_groups={{0}}, to_apply=%sum.f32
+  ROOT %tuple = (f32[]) tuple(%all-reduce.1)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module,
+                          ParseAndReturnVerifiedModule(module_str));
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1,
+                         /*spmd_partition=*/true);
   auto changed = combiner.Run(module.get()).ValueOrDie();
   EXPECT_FALSE(changed);
 }
@@ -1291,7 +1693,36 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
 
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module,
                           ParseAndReturnVerifiedModule(module_str));
-  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/false);
+  auto changed = combiner.Run(module.get()).ValueOrDie();
+  EXPECT_FALSE(changed);
+}
+
+TEST_F(ArCrsCombinerTest, AllReduceWithReplicasSPMD) {
+  const char* module_str = R"(
+HloModule foobar
+
+%sum.f32 (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(%x, %y)
+}
+
+ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
+  %p = bf16[] parameter(0)
+  %all-reduce.0 = f32[] all-reduce(%p), channel_id=1, replica_groups={{0,1}},
+    to_apply=%sum.f32
+  %all-reduce.2 = f32[] all-reduce(%all-reduce.0), replica_groups={{0,1}},
+    to_apply=%sum.f32
+  ROOT %tuple = (f32[]) tuple(%all-reduce.2)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module,
+                          ParseAndReturnVerifiedModule(module_str));
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/true);
   auto changed = combiner.Run(module.get()).ValueOrDie();
   EXPECT_FALSE(changed);
 }
diff --git a/tensorflow/compiler/xla/service/hlo_replication_analysis.cc b/tensorflow/compiler/xla/service/hlo_replication_analysis.cc
index e11d3920f95..3a896d4a113 100644
--- a/tensorflow/compiler/xla/service/hlo_replication_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_replication_analysis.cc
@@ -35,13 +35,45 @@ namespace {
 // knowledge in hlo_replication.
 bool DetermineHloInstructionIsReplicated(
     const HloInstruction* hlo, const ShapeIndex& index,
+    bool cross_partition_spmd,
     const absl::flat_hash_map>&
         hlo_replication) {
+  // Returns true if all operands are known to be replicated.
+  const auto all_operands_replicated =
+      [&hlo_replication](const HloInstruction* inst) {
+        for (auto operand : inst->operands()) {
+          auto operand_it = hlo_replication.find(operand);
+          if (operand_it == hlo_replication.end() ||
+              !operand_it->second.element({})) {
+            return false;
+          }
+        }
+        return true;
+      };
+
+  if (hlo->IsCrossReplicaAllReduce()) {
+    if (cross_partition_spmd) {
+      // Cross-replica all-reduce returns same values across partitions as long
+      // as its operands are replicated.
+      return all_operands_replicated(hlo);
+    }
+    // Only all-reduce across all cores are replicated, which means there
+    // is only one subgroup.
+    return hlo->replica_groups().empty() || hlo->replica_groups().size() == 1;
+  }
+  if (hlo->IsCrossModuleAllReduce()) {
+    return cross_partition_spmd;
+  }
   if (hlo->HasSideEffectNoRecurse()) {
     return false;
   }
   if (hlo->opcode() == HloOpcode::kReplicaId) {
-    return false;
+    // ReplicaId returns the same value for all partitions in each replica.
+    return cross_partition_spmd;
+  }
+  if (hlo->opcode() == HloOpcode::kPartitionId) {
+    // PartitionId returns the same value for all replicas in each partition.
+    return !cross_partition_spmd;
   }
   auto it = hlo_replication.find(hlo);
   if (hlo->opcode() == HloOpcode::kParameter) {
@@ -55,11 +87,6 @@ bool DetermineHloInstructionIsReplicated(
   if (hlo->opcode() == HloOpcode::kConstant) {
     return true;
   }
-  if (hlo->opcode() == HloOpcode::kAllReduce) {
-    // Only all-reduce across all cores are replicated, which means there
-    // is only one subgroup.
-    return hlo->replica_groups().empty() || hlo->replica_groups().size() == 1;
-  }
 
   if (hlo->IsElementwise() ||                             //
       hlo->opcode() == HloOpcode::kConcatenate ||         //
@@ -80,14 +107,7 @@ bool DetermineHloInstructionIsReplicated(
       hlo->opcode() == HloOpcode::kDynamicUpdateSlice ||  //
       hlo->opcode() == HloOpcode::kReduceWindow ||        //
       hlo->opcode() == HloOpcode::kCopy) {
-    for (auto operand : hlo->operands()) {
-      auto operand_it = hlo_replication.find(operand);
-      if (operand_it == hlo_replication.end() ||
-          !operand_it->second.element({})) {
-        return false;
-      }
-    }
-    return true;
+    return all_operands_replicated(hlo);
   }
   return false;
 }
@@ -235,8 +255,8 @@ bool HloReplicationAnalysis::ComputeHloReplicationOnComputation(
         ShapeUtil::ForEachSubshape(
             inst->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
               *shape_tree.mutable_element(index) =
-                  DetermineHloInstructionIsReplicated(inst, index,
-                                                      hlo_replication_);
+                  DetermineHloInstructionIsReplicated(
+                      inst, index, cross_partition_spmd_, hlo_replication_);
               return Status::OK();
             });
         changed |= assign_or_combine_shapetree(std::move(shape_tree), inst);
@@ -248,23 +268,39 @@ bool HloReplicationAnalysis::ComputeHloReplicationOnComputation(
 
 void HloReplicationAnalysis::ComputeHloReplication() {
   // Add entry parameters to the above sets according to user annotation.
+  // Replicated modules read from `parameter_replicated_at_leaf_buffers` whereas
+  // SPMD partitioned modules read from HloSharding attributes.
   auto entry = module_->entry_computation();
   for (int i = 0; i < entry->num_parameters(); ++i) {
     auto param = entry->parameter_instruction(i);
     ShapeTree shape_tree(param->shape(), false);
-    const auto& replication = param->parameter_replicated_at_leaf_buffers();
-    int leaf_index = 0;
-    ShapeUtil::ForEachSubshape(
-        param->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
-          if (!ShapeUtil::IsLeafIndex(param->shape(), index)) {
+    if (cross_partition_spmd_ && param->has_sharding()) {
+      auto sharding_tree =
+          param->sharding().AsShapeTree(param->shape()).ValueOrDie();
+      ShapeUtil::ForEachSubshape(
+          param->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
+            if (!ShapeUtil::IsLeafIndex(param->shape(), index)) {
+              return Status::OK();
+            }
+            *shape_tree.mutable_element(index) =
+                sharding_tree.element(index).IsReplicated();
             return Status::OK();
-          }
-          if (replication && replication->at(leaf_index)) {
-            *shape_tree.mutable_element(index) = true;
-          }
-          ++leaf_index;
-          return Status::OK();
-        });
+          });
+    } else if (!cross_partition_spmd_) {
+      const auto& replication = param->parameter_replicated_at_leaf_buffers();
+      int leaf_index = 0;
+      ShapeUtil::ForEachSubshape(
+          param->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
+            if (!ShapeUtil::IsLeafIndex(param->shape(), index)) {
+              return Status::OK();
+            }
+            if (replication && replication->at(leaf_index)) {
+              *shape_tree.mutable_element(index) = true;
+            }
+            ++leaf_index;
+            return Status::OK();
+          });
+    }
     hlo_replication_[param] = std::move(shape_tree);
   }
   ComputeHloReplicationOnComputation(entry,
@@ -281,17 +317,18 @@ bool HloReplicationAnalysis::HloInstructionIsReplicatedAt(
 }
 
 /* static */ StatusOr>
-HloReplicationAnalysis::Run(const HloModule* module) {
+HloReplicationAnalysis::Run(const HloModule* module,
+                            bool cross_partition_spmd) {
   const absl::flat_hash_set empty;
-  return Run(module, &empty);
+  return Run(module, cross_partition_spmd, &empty);
 }
 
 /* static */ StatusOr>
-HloReplicationAnalysis::Run(const HloModule* module,
+HloReplicationAnalysis::Run(const HloModule* module, bool cross_partition_spmd,
                             const absl::flat_hash_set*
                                 loops_known_with_same_iterations) {
-  auto analysis = absl::WrapUnique(
-      new HloReplicationAnalysis(module, loops_known_with_same_iterations));
+  auto analysis = absl::WrapUnique(new HloReplicationAnalysis(
+      module, cross_partition_spmd, loops_known_with_same_iterations));
   analysis->ComputeHloReplication();
   return analysis;
 }
diff --git a/tensorflow/compiler/xla/service/hlo_replication_analysis.h b/tensorflow/compiler/xla/service/hlo_replication_analysis.h
index 3175fc35102..18b2363e454 100644
--- a/tensorflow/compiler/xla/service/hlo_replication_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_replication_analysis.h
@@ -25,32 +25,35 @@ limitations under the License.
 namespace xla {
 
 // An HLO pass that determines whether each instruction in the module outputs
-// the same value across replicas. It propagates sources of replicated values to
+// the same value across replicas or across partitions (depending on the value
+// `cross_partition_spmd`). It propagates sources of replicated values to
 // the rest of the module, where sources include cross-replica-sum, annotated
 // entry parameters, and constants.
 class HloReplicationAnalysis {
  public:
   // Runs the analysis on module and returns the result or an error.
   static StatusOr> Run(
-      const HloModule* module);
+      const HloModule* module, bool cross_partition_spmd);
 
   // Same as above, but the caller can provide additional annotations: a set of
   // while loops that are known to have the same iteration counts across
-  // replicas.
+  // replicas or partitions.
   static StatusOr> Run(
-      const HloModule* module, const absl::flat_hash_set*
-                                   loops_known_with_same_iterations);
+      const HloModule* module, bool cross_partition_spmd,
+      const absl::flat_hash_set*
+          loops_known_with_same_iterations);
 
   // Returns if the HLO instruction outputs the same value (i.e., replicated) at
-  // the given index across all replicas.
+  // the given index across all replicas or partitions.
   bool HloInstructionIsReplicatedAt(const HloInstruction* inst,
                                     const ShapeIndex& index) const;
 
  private:
-  HloReplicationAnalysis(const HloModule* module,
+  HloReplicationAnalysis(const HloModule* module, bool cross_partition_spmd,
                          const absl::flat_hash_set*
                              loops_known_with_same_iterations)
       : module_(module),
+        cross_partition_spmd_(cross_partition_spmd),
         loops_known_with_same_iterations_(*loops_known_with_same_iterations) {}
 
   // Computes hlo_replication_.
@@ -63,14 +66,25 @@ class HloReplicationAnalysis {
 
   const HloModule* module_;
 
+  // If true, run this replication analysis for replicated values across
+  // partitions (not across replicas) on an SPMD partitioned module. This means
+  // that HloInstructionIsReplicatedAt() returns true if the value is identical
+  // across partitions for each replica. The module-level parameter and root
+  // instructions may have HloSharding attributes that indicate whether values
+  // are identical across partitions.
+  //
+  // If false, HloReplicationAnalysis runs across replicas.
+  bool cross_partition_spmd_;
+
   // A set of while loops that are known to have the same iteration counts
-  // across replicas. This is provided by the caller as additional annotations.
+  // across replicas or partitions. This is provided by the caller as additional
+  // annotations.
   const absl::flat_hash_set&
       loops_known_with_same_iterations_;
 
   // A map from each analyzed HLO instruction to a shape tree that represents
-  // whether the instruction outputs the same value across replicas at each
-  // shape index.
+  // whether the instruction outputs the same value across replicas or
+  // partitions at each shape index.
   absl::flat_hash_map> hlo_replication_;
 };
 
diff --git a/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc
index 958e99dedb8..56cc8542ac4 100644
--- a/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc
@@ -42,16 +42,30 @@ sum {
   ROOT add.2 = f32[] add(a, b)
 }
 
+sum.u32 {
+  a = u32[] parameter(0)
+  b = u32[] parameter(1)
+  ROOT add.2 = u32[] add(a, b)
+}
+
 ENTRY entry {
   param = (f32[4096,4096]{1,0}, f32[4096,4096]{1,0}) parameter(0)
   get-tuple-element.2 = f32[4096,4096]{1,0} get-tuple-element(param), index=0
   get-tuple-element.3 = f32[4096,4096]{1,0} get-tuple-element(param), index=1
   after-all.1 = token[] after-all()
+  replica-id = u32[] replica-id()
+  partition-id = u32[] partition-id()
   infeed = (f32[4096,4096]{1,0}, token[]) infeed(after-all.1)
   get-tuple-element.5 = f32[4096,4096]{1,0} get-tuple-element(infeed), index=0
-  dot = f32[4096,4096]{1,0} dot(get-tuple-element.5, get-tuple-element.3), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  all-reduce = f32[4096,4096]{1,0} all-reduce(dot), replica_groups={}, to_apply=sum
+  dot = f32[4096,4096]{1,0} dot(get-tuple-element.5, get-tuple-element.3),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  all-reduce = f32[4096,4096]{1,0} all-reduce(dot), replica_groups={},
+    to_apply=sum
   subtract = f32[4096,4096]{1,0} subtract(get-tuple-element.3, all-reduce)
+  all-reduce-partitions = u32[] all-reduce(partition-id), channel_id=1,
+    to_apply=sum.u32
+  all-reduce-subgroup = u32[] all-reduce(partition-id),
+    replica_groups={{0,1},{2,3}}, to_apply=sum.u32
   ROOT add = f32[4096,4096]{1,0} add(get-tuple-element.2, subtract)
 }
 )";
@@ -62,7 +76,8 @@ ENTRY entry {
   param->set_parameter_replicated_at_leaf_buffers(
       absl::Span{false, true});
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis,
-                          HloReplicationAnalysis::Run(module.get()));
+                          HloReplicationAnalysis::Run(
+                              module.get(), /*cross_partition_spmd=*/false));
   EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
       FindInstruction(module.get(), "get-tuple-element.2"), {}));
   EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
@@ -77,6 +92,92 @@ ENTRY entry {
       FindInstruction(module.get(), "subtract"), {}));
   EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
       FindInstruction(module.get(), "add"), {}));
+  EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "replica-id"), {}));
+  EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "partition-id"), {}));
+  EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "all-reduce-partitions"), {}));
+  EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "all-reduce-subgroup"), {}));
+}
+
+TEST_F(HloReplicationAnalysisTest, NoControlFlowSPMD) {
+  const string module_str = R"(
+HloModule NoControlFlow
+
+sum {
+  a = f32[] parameter(0)
+  b = f32[] parameter(1)
+  ROOT add.2 = f32[] add(a, b)
+}
+
+sum.u32 {
+  a = u32[] parameter(0)
+  b = u32[] parameter(1)
+  ROOT add.2 = u32[] add(a, b)
+}
+
+ENTRY entry {
+  param = (f32[4096,4096]{1,0}, f32[4096,4096]{1,0}) parameter(0),
+    sharding={{maximal device=0}, {replicated}}
+  get-tuple-element.2 = f32[4096,4096]{1,0} get-tuple-element(param), index=0
+  get-tuple-element.3 = f32[4096,4096]{1,0} get-tuple-element(param), index=1
+  after-all.1 = token[] after-all()
+  replica-id = u32[] replica-id()
+  partition-id = u32[] partition-id()
+  infeed = (f32[4096,4096]{1,0}, token[]) infeed(after-all.1)
+  get-tuple-element.5 = f32[4096,4096]{1,0} get-tuple-element(infeed), index=0
+  dot = f32[4096,4096]{1,0} dot(get-tuple-element.5, get-tuple-element.3),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  all-reduce = f32[4096,4096]{1,0} all-reduce(dot), replica_groups={},
+    to_apply=sum
+  all-reduce-subgroup = f32[4096,4096]{1,0} all-reduce(dot),
+    replica_groups={{0,1},{2,3}}, to_apply=sum
+  all-reduce-partitions = f32[4096,4096]{1,0} all-reduce(get-tuple-element.2),
+    channel_id=1, to_apply=sum
+  subtract = f32[4096,4096]{1,0} subtract(get-tuple-element.3,
+    all-reduce-partitions)
+  all-reduce-same-operand = u32[] all-reduce(replica-id), to_apply=sum.u32
+  all-reduce-same-operand-subgroup = u32[] all-reduce(replica-id),
+    replica_groups={{0,1},{2,3}}, to_apply=sum.u32
+  all-reduce-different-operand = u32[] all-reduce(partition-id),
+    to_apply=sum.u32
+  ROOT add = f32[4096,4096]{1,0} add(get-tuple-element.2, subtract)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(module_str));
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr analysis,
+      HloReplicationAnalysis::Run(module.get(), /*cross_partition_spmd=*/true));
+  EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "get-tuple-element.2"), {}));
+  EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "get-tuple-element.3"), {}));
+  EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "get-tuple-element.5"), {}));
+  EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "dot"), {}));
+  EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "all-reduce"), {}));
+  EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "subtract"), {}));
+  EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "add"), {}));
+  EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "replica-id"), {}));
+  EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "partition-id"), {}));
+  EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "all-reduce-partitions"), {}));
+  EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "all-reduce-same-operand"), {}));
+  EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "all-reduce-same-operand-subgroup"), {}));
+  EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "all-reduce-different-operand"), {}));
 }
 
 TEST_F(HloReplicationAnalysisTest, NestedCall) {
@@ -111,7 +212,8 @@ ENTRY entry {
   param->set_parameter_replicated_at_leaf_buffers(
       absl::Span{true, false});
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis,
-                          HloReplicationAnalysis::Run(module.get()));
+                          HloReplicationAnalysis::Run(
+                              module.get(), /*cross_partition_spmd=*/false));
   EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
       FindInstruction(module.get(), "get-tuple-element"), {}));
   EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
@@ -163,7 +265,8 @@ ENTRY SimpleWhileLoop {
   param->set_parameter_replicated_at_leaf_buffers(
       absl::Span{true, true});
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis,
-                          HloReplicationAnalysis::Run(module.get()));
+                          HloReplicationAnalysis::Run(
+                              module.get(), /*cross_partition_spmd=*/false));
   EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
       FindInstruction(module.get(), "tuple"), {0}));
   EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
@@ -212,7 +315,8 @@ ENTRY WhileLoopParameterAliasingNonReplicatedOutput {
   param->set_parameter_replicated_at_leaf_buffers(
       absl::Span{true, true});
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis,
-                          HloReplicationAnalysis::Run(module.get()));
+                          HloReplicationAnalysis::Run(
+                              module.get(), /*cross_partition_spmd=*/false));
   EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
       FindInstruction(module.get(), "multiply"), {}));
   EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
@@ -258,7 +362,8 @@ ENTRY WhileLoopDifferentCondition {
   param->set_parameter_replicated_at_leaf_buffers(
       absl::Span{true, true});
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis,
-                          HloReplicationAnalysis::Run(module.get()));
+                          HloReplicationAnalysis::Run(
+                              module.get(), /*cross_partition_spmd=*/false));
   EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
       FindInstruction(module.get(), "while"), {0}));
   EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
@@ -307,7 +412,8 @@ ENTRY entry {
   param->set_parameter_replicated_at_leaf_buffers(
       absl::Span{true, true, true, true, false, true, true});
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis,
-                          HloReplicationAnalysis::Run(module.get()));
+                          HloReplicationAnalysis::Run(
+                              module.get(), /*cross_partition_spmd=*/false));
   EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
       FindInstruction(module.get(), "tuple"), {0}));
   EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
@@ -371,7 +477,8 @@ ENTRY entry {
   param->set_parameter_replicated_at_leaf_buffers(
       absl::Span{true, true, true, true, true, true});
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis,
-                          HloReplicationAnalysis::Run(module.get()));
+                          HloReplicationAnalysis::Run(
+                              module.get(), /*cross_partition_spmd=*/false));
   EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
       FindInstruction(module.get(), "tuple"), {0}));
   EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
@@ -409,7 +516,8 @@ ENTRY entry {
   param->set_parameter_replicated_at_leaf_buffers(
       absl::Span{true, false, true, true, true});
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis,
-                          HloReplicationAnalysis::Run(module.get()));
+                          HloReplicationAnalysis::Run(
+                              module.get(), /*cross_partition_spmd=*/false));
   EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
       FindInstruction(module.get(), "tuple-select"), {0}));
   EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
@@ -435,7 +543,8 @@ ENTRY entry {
   param->set_parameter_replicated_at_leaf_buffers(
       absl::Span{true, true, true, true, false});
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis,
-                          HloReplicationAnalysis::Run(module.get()));
+                          HloReplicationAnalysis::Run(
+                              module.get(), /*cross_partition_spmd=*/false));
   EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
       FindInstruction(module.get(), "tuple-select"), {0}));
   EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(

From 17f3e8ad39919a32db99ff0666c1954d92490987 Mon Sep 17 00:00:00 2001
From: Juhyun Lee 
Date: Tue, 3 Dec 2019 13:37:51 -0800
Subject: [PATCH 236/279] Make a couple of targets in grappler buildable for
 Android.

PiperOrigin-RevId: 283612146
Change-Id: If64befd6fcd53500413c98c0978111b5abc06bcd
---
 tensorflow/core/grappler/BUILD       | 18 ++++++++++++------
 tensorflow/core/grappler/utils/BUILD | 14 ++++++++++----
 2 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/tensorflow/core/grappler/BUILD b/tensorflow/core/grappler/BUILD
index 3f79c023caf..fd2ea4f893c 100644
--- a/tensorflow/core/grappler/BUILD
+++ b/tensorflow/core/grappler/BUILD
@@ -23,14 +23,20 @@ cc_library(
     hdrs = ["utils.h"],
     visibility = ["//visibility:public"],
     deps = [
-        "//tensorflow/core:framework",
-        "//tensorflow/core:graph",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:lib_internal",
-        "//tensorflow/core:protos_all_cc",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
-    ],
+    ] + select({
+        "//tensorflow:android": [
+            "//tensorflow/core:android_tensorflow_lib",
+        ],
+        "//conditions:default": [
+            "//tensorflow/core:framework",
+            "//tensorflow/core:graph",
+            "//tensorflow/core:lib",
+            "//tensorflow/core:lib_internal",
+            "//tensorflow/core:protos_all_cc",
+        ],
+    }),
 )
 
 tf_cc_test(
diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD
index 8941d5552b6..7572141d415 100644
--- a/tensorflow/core/grappler/utils/BUILD
+++ b/tensorflow/core/grappler/utils/BUILD
@@ -386,11 +386,17 @@ cc_library(
     hdrs = ["transitive_fanin.h"],
     visibility = ["//visibility:public"],
     deps = [
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/grappler:utils",
-    ],
+    ] + select({
+        "//tensorflow:android": [
+            "//tensorflow/core:android_tensorflow_lib",
+        ],
+        "//conditions:default": [
+            "//tensorflow/core:framework",
+            "//tensorflow/core:lib",
+            "//tensorflow/core:protos_all_cc",
+        ],
+    }),
 )
 
 tf_cc_test(

From af79ee35f5bc95eb002f10f0a1a43f6f6f864a29 Mon Sep 17 00:00:00 2001
From: Amit Patankar 
Date: Tue, 3 Dec 2019 13:55:22 -0800
Subject: [PATCH 237/279] [XLA] Refactor Executable::ExecuteAsyncOnStream.

Change implementations of Executable to always implement the overload that takes a std::vector>. Make the non-owning version a wrapper around the maybe-owning version.

Simplification in preparation for plumbing buffer donation into JAX. This change is also a necessary preparatory step for implementing buffer donation on CPU and GPU.

PiperOrigin-RevId: 283615681
Change-Id: I0d3c65bee506822d23e5827493213e0921b4ef9e
---
 tensorflow/compiler/xla/service/cpu/BUILD     |  7 ----
 .../xla/service/cpu/cpu_executable.cc         | 40 ++++++------------
 .../compiler/xla/service/cpu/cpu_executable.h | 14 +++----
 tensorflow/compiler/xla/service/executable.cc | 38 ++++-------------
 tensorflow/compiler/xla/service/executable.h  | 10 ++---
 .../xla/service/gpu/gpu_executable.cc         | 39 ++++++++----------
 .../compiler/xla/service/gpu/gpu_executable.h |  9 +++-
 .../service/hlo_input_output_alias_config.cc  |  3 +-
 .../compiler/xla/service/interpreter/BUILD    |  5 ---
 .../xla/service/interpreter/executable.cc     | 41 ++++---------------
 .../xla/service/interpreter/executable.h      |  4 +-
 .../xla/service/maybe_owning_device_memory.cc |  3 +-
 .../xla/service/maybe_owning_device_memory.h  |  2 +-
 13 files changed, 69 insertions(+), 146 deletions(-)

diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index af856e92e70..229827c77c8 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -242,16 +242,9 @@ cc_library(
         "//tensorflow/compiler/xla/service:hlo_dataflow_analysis",
         "//tensorflow/compiler/xla/service:hlo_execution_profile",
         "//tensorflow/compiler/xla/service:logical_buffer",
-        "//tensorflow/compiler/xla/service:maybe_owning_device_memory",
         "//tensorflow/compiler/xla/service:shaped_buffer",
         "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor_no_cuda",
-        "//tensorflow/core/platform:env",
-        "//tensorflow/core/platform:logging",
-        "//tensorflow/core/platform:macros",
-        "//tensorflow/core/platform:mutex",
-        "//tensorflow/core/platform:platform_port",
-        "//tensorflow/core/platform:types",
         "//tensorflow/core/profiler/lib:traceme",
         "//tensorflow/stream_executor:device_memory_allocator",
         "//tensorflow/stream_executor/host:host_stream",
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index 083c3d31d74..9b79e8ca8d7 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -32,7 +32,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/logical_buffer.h"
-#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
 #include "tensorflow/compiler/xla/shape_tree.h"
 #include "tensorflow/compiler/xla/shape_util.h"
@@ -45,7 +44,6 @@ limitations under the License.
 #include "tensorflow/core/platform/mem.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/types.h"
-#include "tensorflow/stream_executor/device_memory_allocator.h"
 #include "tensorflow/stream_executor/host/host_stream.h"
 
 namespace xla {
@@ -75,12 +73,11 @@ CpuExecutable::CpuExecutable(
           << reinterpret_cast(compute_function_);
 }
 
-StatusOr,
-                    std::vector,
-                    std::vector>>
+StatusOr,
+                   std::vector>>
 CpuExecutable::CreateBufferTable(
     se::DeviceMemoryAllocator* memory_allocator, int device_ordinal,
-    std::vector> arguments) {
+    absl::Span arguments) {
   std::vector unowning_buffers(
       assignment_->Allocations().size());
   std::vector owning_buffers(
@@ -94,9 +91,8 @@ CpuExecutable::CreateBufferTable(
     VLOG(3) << allocation.ToString();
 
     if (allocation.is_entry_computation_parameter()) {
-      unowning_buffers[i] = arguments[allocation.parameter_number()]
-                                .element(allocation.param_shape_index())
-                                .AsDeviceMemoryBase();
+      unowning_buffers[i] = arguments[allocation.parameter_number()]->buffer(
+          allocation.param_shape_index());
       CHECK_EQ(allocation.size(), unowning_buffers[i].size())
           << "Size mismatch on param " << allocation.parameter_number()
           << " at shape index " << allocation.param_shape_index().ToString();
@@ -138,17 +134,7 @@ CpuExecutable::CreateBufferTable(
                       assignment_->GetUniqueTopLevelOutputSlice());
   VLOG(3) << "result index: " << result_slice.index();
 
-  std::vector buffers_to_free;
-  for (ShapeTree& argument : arguments) {
-    for (std::pair& buffer : argument) {
-      auto maybe_owning_buffer = buffer.second.Release();
-      if (maybe_owning_buffer) {
-        buffers_to_free.push_back(std::move(*maybe_owning_buffer));
-      }
-    }
-  }
-  return {{std::move(unowning_buffers), std::move(owning_buffers),
-           std::move(buffers_to_free)}};
+  return {{std::move(unowning_buffers), std::move(owning_buffers)}};
 }
 
 Status CpuExecutable::ExecuteComputeFunction(
@@ -282,9 +268,9 @@ StatusOr CpuExecutable::CreateResultShapedBuffer(
   return std::move(result_buffer);
 }
 
-StatusOr CpuExecutable::ExecuteAsyncOnStream(
+StatusOr CpuExecutable::ExecuteAsyncOnStream(
     const ServiceExecutableRunOptions* run_options,
-    std::vector> arguments,
+    absl::Span arguments,
     HloExecutionProfile* hlo_execution_profile) {
   if (GetRootValueSet().IsAmbiguous()) {
     return Unimplemented("Points-to set of root instruction is ambiguous");
@@ -297,7 +283,7 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream(
     for (int64 i = 0; i < entry_comp->num_parameters(); ++i) {
       const Shape& expected_shape =
           entry_comp->parameter_instruction(i)->shape();
-      const Shape& actual_shape = arguments[i].shape();
+      const Shape& actual_shape = arguments[i]->on_device_shape();
       CHECK(expected_shape == actual_shape) << absl::StreamFormat(
           "Shape mismatch on argument %d.  Expected %s, but was %s.", i,
           expected_shape.ToString(/*print_layout=*/true),
@@ -311,11 +297,10 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream(
   se::DeviceMemoryAllocator* memory_allocator = run_options->allocator();
   std::vector owning_buffers;
   std::vector unowning_buffers;
-  std::vector buffers_to_release;
   TF_ASSIGN_OR_RETURN(
-      std::tie(unowning_buffers, owning_buffers, buffers_to_release),
+      std::tie(unowning_buffers, owning_buffers),
       CreateBufferTable(memory_allocator, stream->parent()->device_ordinal(),
-                        std::move(arguments)));
+                        arguments));
 
   TF_ASSIGN_OR_RETURN(
       ScopedShapedBuffer result,
@@ -354,8 +339,7 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream(
                        std::move(owning_buffers)),
                    hlo_execution_profile});
 
-  return ExecutionOutput(std::move(result), std::move(buffers_to_release), {},
-                         se::OwningDeviceMemory());
+  return std::move(result);
 }
 
 /*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) {
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
index 6f8a7c3315a..37af630a2d9 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
@@ -55,9 +55,9 @@ class CpuExecutable : public Executable {
                 std::unique_ptr hlo_profile_index_map);
   ~CpuExecutable() override {}
 
-  StatusOr ExecuteAsyncOnStream(
+  StatusOr ExecuteAsyncOnStream(
       const ServiceExecutableRunOptions* run_options,
-      std::vector> arguments,
+      absl::Span arguments,
       HloExecutionProfile* hlo_execution_profile) override;
 
   // This should be called after set_ir_module_string.
@@ -96,15 +96,11 @@ class CpuExecutable : public Executable {
   //    allocated by this routine.  This routine allocates buffers for temporary
   //    storage and the live-out buffer into which the computation writes it
   //    result.
-  //
-  //  - buffers_to_free: buffers whose ownership was donated by the caller that
-  //    are to be freed by the caller.
-  StatusOr,
-                      std::vector,
-                      std::vector>>
+  StatusOr,
+                     std::vector>>
   CreateBufferTable(se::DeviceMemoryAllocator* memory_allocator,
                     int device_ordinal,
-                    std::vector> arguments);
+                    absl::Span arguments);
 
   // Calls the generated function performing the computation with the given
   // arguments using the supplied buffers.
diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc
index 9ece6172d12..c21721c9339 100644
--- a/tensorflow/compiler/xla/service/executable.cc
+++ b/tensorflow/compiler/xla/service/executable.cc
@@ -20,7 +20,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/debug_options_flags.h"
 #include "tensorflow/compiler/xla/service/dump.h"
 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
-#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
 #include "tensorflow/compiler/xla/status.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -44,36 +43,9 @@ StatusOr Executable::ExecuteOnStream(
   return result;
 }
 
-static ShapeTree MakeMaybeOwningDeviceMemoryTree(
-    const ShapedBuffer& shaped_buffer) {
-  ShapeTree result(shaped_buffer.on_device_shape());
-  auto in_it = shaped_buffer.buffers().begin();
-  auto out_it = result.begin();
-  for (; in_it != shaped_buffer.buffers().end(); ++in_it, ++out_it) {
-    DCHECK(out_it != result.end());
-    out_it->second = MaybeOwningDeviceMemory(in_it->second);
-  }
-  return result;
-}
-
-StatusOr Executable::ExecuteAsyncOnStream(
-    const ServiceExecutableRunOptions* run_options,
-    absl::Span arguments,
-    HloExecutionProfile* hlo_execution_profile) {
-  std::vector> args(arguments.size());
-  auto out_it = args.begin();
-  for (const ShapedBuffer* arg : arguments) {
-    *out_it++ = MakeMaybeOwningDeviceMemoryTree(*arg);
-  }
-  TF_ASSIGN_OR_RETURN(ExecutionOutput out,
-                      ExecuteAsyncOnStream(run_options, std::move(args),
-                                           hlo_execution_profile));
-  return out.ConsumeResult();
-}
-
 StatusOr Executable::ExecuteOnStream(
     const ServiceExecutableRunOptions* run_options,
-    std::vector> arguments,
+    std::vector> arguments,
     HloExecutionProfile* hlo_execution_profile) {
   StatusOr result = ExecuteAsyncOnStream(
       run_options, std::move(arguments), hlo_execution_profile);
@@ -83,6 +55,14 @@ StatusOr Executable::ExecuteOnStream(
   return result;
 }
 
+StatusOr Executable::ExecuteAsyncOnStream(
+    const ServiceExecutableRunOptions* /*run_options*/,
+    std::vector> /*arguments*/,
+    HloExecutionProfile* /*hlo_execution_profile*/) {
+  return Unimplemented(
+      "MaybeOwningDeviceMemory version of overload is not implemented ");
+}
+
 StatusOr> Executable::ExecuteOnStreams(
     absl::Span run_options,
     absl::Span> arguments) {
diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h
index 496599e7aaf..971dab95bfd 100644
--- a/tensorflow/compiler/xla/service/executable.h
+++ b/tensorflow/compiler/xla/service/executable.h
@@ -160,22 +160,22 @@ class Executable {
   // If the hlo_execution_profile is provided as non-nullptr, profiling will be
   // enabled. Note that profiling is tricky to use correctly, as the profiling
   // objects (when they exist) must out-live the task.
-  StatusOr ExecuteAsyncOnStream(
+  virtual StatusOr ExecuteAsyncOnStream(
       const ServiceExecutableRunOptions* run_options,
       absl::Span arguments,
-      HloExecutionProfile* hlo_execution_profile);
+      HloExecutionProfile* hlo_execution_profile) = 0;
 
   // Same as ExecuteAsyncOnStream(), but blocks waiting for the computation to
   // complete.
   StatusOr ExecuteOnStream(
       const ServiceExecutableRunOptions* run_options,
-      std::vector> arguments,
+      std::vector> arguments,
       HloExecutionProfile* hlo_execution_profile);
 
   virtual StatusOr ExecuteAsyncOnStream(
       const ServiceExecutableRunOptions* run_options,
-      std::vector> arguments,
-      HloExecutionProfile* hlo_execution_profile) = 0;
+      std::vector> arguments,
+      HloExecutionProfile* hlo_execution_profile);
 
   // Same as ExecuteOnStream(), but runs this executable on multiple
   // streams. arguments[i] contains the arguments to the execution on
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index 93af1cd995e..99bc0f7fee0 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -299,14 +299,11 @@ GpuExecutable::ResolveConstantGlobals(se::Stream* stream) {
   return &module_globals_.emplace(executor, std::move(globals)).first->second;
 }
 
-StatusOr GpuExecutable::ExecuteAsyncOnStream(
+StatusOr GpuExecutable::Execute(
     const ServiceExecutableRunOptions* run_options,
-    std::vector> arguments,
-    HloExecutionProfile* hlo_execution_profile) {
-  se::DeviceMemoryAllocator* const memory_allocator = run_options->allocator();
-  // Force synchronous execution if the allocator requires it.
-  const bool block_host_until_done =
-      !memory_allocator->AllowsAsynchronousDeallocation();
+    absl::Span arguments,
+    HloExecutionProfile* hlo_execution_profile, bool block_host_until_done) {
+  se::DeviceMemoryAllocator* memory_allocator = run_options->allocator();
 
   if (GetRootValueSet().IsAmbiguous()) {
     return Unimplemented("Points-to set of root instruction is ambiguous");
@@ -337,9 +334,7 @@ StatusOr GpuExecutable::ExecuteAsyncOnStream(
       if (allocation.is_entry_computation_parameter()) {
         auto param_no = allocation.parameter_number();
         se::DeviceMemoryBase buffer =
-            arguments[param_no]
-                .element(allocation.param_shape_index())
-                .AsDeviceMemoryBase();
+            arguments[param_no]->buffer(allocation.param_shape_index());
 
         // All top-level buffers and sub-buffers must have an explicit, non-null
         // pointer, except for zero-sized buffers, which may be null.
@@ -428,17 +423,19 @@ StatusOr GpuExecutable::ExecuteAsyncOnStream(
       }));
   TF_RETURN_IF_ERROR(buffer_allocations->TearDown(buffers_in_result));
 
-  std::vector buffers_to_free;
-  for (ShapeTree& argument : arguments) {
-    for (std::pair& buffer : argument) {
-      auto maybe_owning_buffer = buffer.second.Release();
-      if (maybe_owning_buffer) {
-        buffers_to_free.push_back(std::move(*maybe_owning_buffer));
-      }
-    }
-  }
-  return ExecutionOutput(std::move(shaped_buffer), std::move(buffers_to_free),
-                         {}, {});
+  return std::move(shaped_buffer);
+}
+
+StatusOr GpuExecutable::ExecuteAsyncOnStream(
+    const ServiceExecutableRunOptions* run_options,
+    absl::Span arguments,
+    HloExecutionProfile* hlo_execution_profile) {
+  se::DeviceMemoryAllocator* memory_allocator = run_options->allocator();
+  // Force synchronous execution if the allocator requires it.
+  bool block_host_until_done =
+      !memory_allocator->AllowsAsynchronousDeallocation();
+  return Execute(run_options, arguments, hlo_execution_profile,
+                 block_host_until_done);
 }
 
 const InstructionValueSet& GpuExecutable::GetRootValueSet() const {
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
index 51e86a9f8ee..66f86d768be 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
@@ -82,9 +82,9 @@ class GpuExecutable : public Executable {
 
   // ExecuteAsyncOnStream will fail if the compute capability of the stream
   // doesn't match the compute capability passed to this object's constructor.
-  StatusOr ExecuteAsyncOnStream(
+  StatusOr ExecuteAsyncOnStream(
       const ServiceExecutableRunOptions* run_options,
-      std::vector> arguments,
+      absl::Span arguments,
       HloExecutionProfile* hlo_execution_profile) override;
 
   std::shared_ptr GetBufferAssignment() const {
@@ -92,6 +92,11 @@ class GpuExecutable : public Executable {
   }
 
  private:
+  StatusOr Execute(
+      const ServiceExecutableRunOptions* run_options,
+      absl::Span arguments,
+      HloExecutionProfile* hlo_execution_profile, bool block_host_until_done);
+
   // If `block_host_until_done` is false, execution will not block the host
   // until the kernels have completed. This is used as an optimization for
   // clients, such as Tensorflow, that use a single stream of execution for
diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc
index 3e82e3271bb..1c5b166a801 100644
--- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc
+++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc
@@ -151,8 +151,7 @@ absl::optional HloInputOutputAliasConfig::GetAliasedOutput(
 absl::optional
 HloInputOutputAliasConfig::GetAliasedParameter(
     const ShapeIndex& output_index) const {
-  CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index))
-      << ToString() << " " << alias_.shape().ToString() << " " << output_index;
+  CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index));
   return alias_.element(output_index);
 }
 
diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD
index 84c7982ad10..3073c68c975 100644
--- a/tensorflow/compiler/xla/service/interpreter/BUILD
+++ b/tensorflow/compiler/xla/service/interpreter/BUILD
@@ -89,15 +89,10 @@ cc_library(
         "//tensorflow/compiler/xla/service:hlo_evaluator",
         "//tensorflow/compiler/xla/service:hlo_execution_profile",
         "//tensorflow/compiler/xla/service:hlo_module_config",
-        "//tensorflow/compiler/xla/service:maybe_owning_device_memory",
         "//tensorflow/compiler/xla/service:shaped_buffer",
         "//tensorflow/compiler/xla/service:transfer_manager",
         "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor_no_cuda",
-        "//tensorflow/core/platform:env",
-        "//tensorflow/core/platform:macros",
-        "//tensorflow/core/platform:mutex",
-        "//tensorflow/core/platform:types",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/types:span",
     ],
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc
index f82a439fdb0..0dab86d986c 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.cc
+++ b/tensorflow/compiler/xla/service/interpreter/executable.cc
@@ -26,7 +26,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/interpreter/executor.h"
-#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
 #include "tensorflow/compiler/xla/service/transfer_manager.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -40,39 +39,24 @@ namespace interpreter {
 InterpreterExecutable::InterpreterExecutable(
     std::unique_ptr hlo_module,
     std::unique_ptr evaluator)
-    : Executable(std::move(hlo_module), /*hlo_profile_printer_data=*/nullptr,
+    : Executable(std::move(hlo_module), /*hlo_profile_printer=*/nullptr,
                  /*hlo_profile_index_map=*/nullptr),
       evaluator_(std::move(evaluator)) {}
 
 InterpreterExecutable::~InterpreterExecutable() {}
 
-StatusOr InterpreterExecutable::ExecuteAsyncOnStream(
+StatusOr InterpreterExecutable::ExecuteAsyncOnStream(
     const ServiceExecutableRunOptions* run_options,
-    std::vector> arguments,
+    absl::Span arguments,
     HloExecutionProfile* hlo_execution_profile) {
   se::Stream* stream = run_options->stream();
   se::StreamExecutor* executor = stream->parent();
   const se::Platform* platform = executor->platform();
 
-  // Convert the ShapeTree to a ShapedBuffer. We do this so we can call
-  // TransferManager methods below.
-  std::vector argument_buffers;
-  argument_buffers.reserve(arguments.size());
-  for (const ShapeTree& arg : arguments) {
-    argument_buffers.push_back(ShapedBuffer(arg.shape(), arg.shape(),
-                                            /*platform=*/nullptr,
-                                            /*device_ordinal=*/0));
-    auto in_it = arg.begin();
-    auto out_it = argument_buffers.back().buffers().begin();
-    for (; in_it != arg.end(); ++in_it, ++out_it) {
-      out_it->second = in_it->second.AsDeviceMemoryBase();
-    }
-  }
-
   VLOG(1) << "Execute " << module().name();
   if (VLOG_IS_ON(2)) {
-    for (const auto& a : argument_buffers) {
-      VLOG(2) << "-- argument " << a;
+    for (const auto& a : arguments) {
+      VLOG(2) << "-- argument " << *a;
     }
   }
 
@@ -87,7 +71,7 @@ StatusOr InterpreterExecutable::ExecuteAsyncOnStream(
   // Check that the args have the right shape.
   for (int64 i = 0; i < computation->num_parameters(); ++i) {
     const auto& expected_shape = computation->parameter_instruction(i)->shape();
-    const auto& actual_shape = argument_buffers[i].on_device_shape();
+    const auto& actual_shape = arguments[i]->on_device_shape();
     if (!Shape::Equal().MinorToMajorOnlyInLayout()(expected_shape,
                                                    actual_shape)) {
       return InvalidArgument(
@@ -106,7 +90,7 @@ StatusOr InterpreterExecutable::ExecuteAsyncOnStream(
   for (int64 p = 0; p < computation->num_parameters(); ++p) {
     TF_ASSIGN_OR_RETURN(Literal arg_literal,
                         transfer_manager->TransferLiteralFromDevice(
-                            run_options->stream(), argument_buffers[p]));
+                            run_options->stream(), *arguments[p]));
     arg_literals.push_back(std::move(arg_literal));
   }
 
@@ -135,16 +119,7 @@ StatusOr InterpreterExecutable::ExecuteAsyncOnStream(
     profile->set_compute_time_ns(std::max(nanoseconds, 1.0));
   }
 
-  std::vector buffers_to_free;
-  for (ShapeTree& argument : arguments) {
-    for (std::pair& buffer : argument) {
-      auto maybe_owning_buffer = buffer.second.Release();
-      if (maybe_owning_buffer) {
-        buffers_to_free.push_back(std::move(*maybe_owning_buffer));
-      }
-    }
-  }
-  return ExecutionOutput(std::move(result), std::move(buffers_to_free), {}, {});
+  return std::move(result);
 }
 
 /*static*/ int64 InterpreterExecutable::ShapeSizeBytes(const Shape& shape) {
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h
index 1bea6773fdd..ba010de76bd 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.h
+++ b/tensorflow/compiler/xla/service/interpreter/executable.h
@@ -46,9 +46,9 @@ class InterpreterExecutable : public Executable {
                         std::unique_ptr evaluator);
   ~InterpreterExecutable() override;
 
-  StatusOr ExecuteAsyncOnStream(
+  StatusOr ExecuteAsyncOnStream(
       const ServiceExecutableRunOptions* run_options,
-      std::vector> arguments,
+      absl::Span arguments,
       HloExecutionProfile* hlo_execution_profile) override
       LOCKS_EXCLUDED(evaluator_lock_);
 
diff --git a/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc b/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc
index c4bf48bcc00..5fe5fea71ac 100644
--- a/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc
+++ b/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc
@@ -17,8 +17,7 @@ limitations under the License.
 #include "absl/types/variant.h"
 namespace xla {
 
-tensorflow::se::DeviceMemoryBase MaybeOwningDeviceMemory::AsDeviceMemoryBase()
-    const {
+tensorflow::se::DeviceMemoryBase MaybeOwningDeviceMemory::AsDeviceMemoryBase() {
   if (HasOwnership()) {
     return *absl::get(mem_);
   } else {
diff --git a/tensorflow/compiler/xla/service/maybe_owning_device_memory.h b/tensorflow/compiler/xla/service/maybe_owning_device_memory.h
index 7d23d178130..8edd64cf681 100644
--- a/tensorflow/compiler/xla/service/maybe_owning_device_memory.h
+++ b/tensorflow/compiler/xla/service/maybe_owning_device_memory.h
@@ -49,7 +49,7 @@ class MaybeOwningDeviceMemory {
 
   // Fetches the underlying DeviceMemoryBase from a MaybeOwningDeviceMemory. The
   // caller of this function is *not* responsible for freeing the memory.
-  tensorflow::se::DeviceMemoryBase AsDeviceMemoryBase() const;
+  tensorflow::se::DeviceMemoryBase AsDeviceMemoryBase();
 
   // Release the tensorflow::se::OwningDeviceMemory without freeing it, and
   // moves the ownership of the memory buffer from the object to the caller.

From d0acd1e26780610552aa2a974ab7b666d682a2c7 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Tue, 3 Dec 2019 13:58:42 -0800
Subject: [PATCH 238/279] PR #27825: TFLite: Div op Neon optimization

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/27825

Added float32 division optimized with Neon SIMD instructions.
Copybara import of the project:

--
084000813642779063a1701b621e86823da5121b by Michal W. Tarnowski :

Non-broadcast Div optimized

--
43a06104a6f3ad4e93099b7f1750948056dd47a7 by Michal W. Tarnowski :

Explicit NEON typenames removed

--
4d9297306254d...

***

ROLLBACK_OF=283557872

BEGIN_PUBLIC

PiperOrigin-RevId: 283616376
Change-Id: I66eeafd640d1d7342877453c52459dec731141ef
---
 .../internal/optimized/optimized_ops.h        | 83 -------------------
 1 file changed, 83 deletions(-)

diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
index d4512409096..26005e069a7 100644
--- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
@@ -2718,89 +2718,6 @@ inline void BroadcastMulDispatch(
                        input2_data, output_shape, output_data);
 }
 
-inline void Div(const ArithmeticParams& params,
-                const RuntimeShape& input1_shape, const float* input1_data,
-                const RuntimeShape& input2_shape, const float* input2_data,
-                const RuntimeShape& output_shape, float* output_data) {
-  gemmlowp::ScopedProfilingLabel label("Div");
-  const float output_activation_min = params.float_activation_min;
-  const float output_activation_max = params.float_activation_max;
-
-  int i = 0;
-  const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape);
-#ifdef USE_NEON
-  // NEON does not offer division instruction, multiplication by the reciprocal
-  // is used instead. This parameter controls the number of Newton-Raphson
-  // iterations used to refine the initial estimate of the reciprocal given by
-  // vrecpeq_f32 instruction. Typically, two iterations are enough to match
-  // the float division accuracy closely.
-  static constexpr int kNewtonSteps = 2;
-  static const auto TWO_F32 = vdupq_n_f32(2.f);
-  const auto activation_min = vdupq_n_f32(output_activation_min);
-  const auto activation_max = vdupq_n_f32(output_activation_max);
-  for (; i <= size - 16; i += 16) {
-    const auto a10 = vld1q_f32(input1_data + i);
-    const auto a11 = vld1q_f32(input1_data + i + 4);
-    const auto a12 = vld1q_f32(input1_data + i + 8);
-    const auto a13 = vld1q_f32(input1_data + i + 12);
-    const auto a20 = vld1q_f32(input2_data + i);
-    const auto a21 = vld1q_f32(input2_data + i + 4);
-    const auto a22 = vld1q_f32(input2_data + i + 8);
-    const auto a23 = vld1q_f32(input2_data + i + 12);
-
-    auto r0 = vrecpeq_f32(a20);
-    auto r1 = vrecpeq_f32(a21);
-    auto r2 = vrecpeq_f32(a22);
-    auto r3 = vrecpeq_f32(a23);
-    for (int k = 0; k < kNewtonSteps; ++k) {
-      r0 = vmulq_f32(r0, vsubq_f32(TWO_F32, vmulq_f32(r0, a20)));
-      r1 = vmulq_f32(r1, vsubq_f32(TWO_F32, vmulq_f32(r1, a21)));
-      r2 = vmulq_f32(r2, vsubq_f32(TWO_F32, vmulq_f32(r2, a22)));
-      r3 = vmulq_f32(r3, vsubq_f32(TWO_F32, vmulq_f32(r3, a23)));
-    }
-
-    auto x0 = vmulq_f32(a10, r0);
-    auto x1 = vmulq_f32(a11, r1);
-    auto x2 = vmulq_f32(a12, r2);
-    auto x3 = vmulq_f32(a13, r3);
-    x0 = vmaxq_f32(activation_min, x0);
-    x1 = vmaxq_f32(activation_min, x1);
-    x2 = vmaxq_f32(activation_min, x2);
-    x3 = vmaxq_f32(activation_min, x3);
-    x0 = vminq_f32(activation_max, x0);
-    x1 = vminq_f32(activation_max, x1);
-    x2 = vminq_f32(activation_max, x2);
-    x3 = vminq_f32(activation_max, x3);
-
-    vst1q_f32(output_data + i, x0);
-    vst1q_f32(output_data + i + 4, x1);
-    vst1q_f32(output_data + i + 8, x2);
-    vst1q_f32(output_data + i + 12, x3);
-  }
-  for (; i <= size - 4; i += 4) {
-    const auto a1 = vld1q_f32(input1_data + i);
-    const auto a2 = vld1q_f32(input2_data + i);
-
-    auto r = vrecpeq_f32(a2);
-    for (int k = 0; k < kNewtonSteps; ++k) {
-      r = vmulq_f32(r, vsubq_f32(TWO_F32, vmulq_f32(r, a2)));
-    }
-
-    auto x = vmulq_f32(a1, r);
-    x = vmaxq_f32(activation_min, x);
-    x = vminq_f32(activation_max, x);
-
-    vst1q_f32(output_data + i, x);
-  }
-#endif  // NEON
-
-  for (; i < size; ++i) {
-    output_data[i] = ActivationFunctionWithMinMax(
-        input1_data[i] / input2_data[i], output_activation_min,
-        output_activation_max);
-  }
-}
-
 // TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
 // dimensionality if the runtime code does a single loop over one dimension
 // that handles broadcasting as the base case. The code generator would then

From 8f467074608f328f8e5becc754ab271847aa2941 Mon Sep 17 00:00:00 2001
From: Sean Silva 
Date: Tue, 3 Dec 2019 14:00:36 -0800
Subject: [PATCH 239/279] Make diagnostic a bit clearer.

This prints out in case of any pass failure. Not just a crash.

PiperOrigin-RevId: 283616719
Change-Id: I31ee68cd17dcc3867f7a5e6a1bf21ca336cecc63
---
 third_party/mlir/lib/Pass/Pass.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/third_party/mlir/lib/Pass/Pass.cpp b/third_party/mlir/lib/Pass/Pass.cpp
index a195bb0c0c8..6d8e230eeec 100644
--- a/third_party/mlir/lib/Pass/Pass.cpp
+++ b/third_party/mlir/lib/Pass/Pass.cpp
@@ -533,7 +533,7 @@ static LogicalResult runWithCrashRecovery(OpPassManager &pm,
   outputFile->keep();
 
   return reproducerModule->emitError()
-         << "A crash has been detected while processing the MLIR module, a "
+         << "A failure has been detected while processing the MLIR module, a "
             "reproducer has been generated in '"
          << crashReproducerFileName << "'";
 }

From 1c81d32696a57ffb9365791ee19f94caba2ffb1d Mon Sep 17 00:00:00 2001
From: Lei Zhang 
Date: Tue, 3 Dec 2019 14:30:41 -0800
Subject: [PATCH 240/279] Add legalization from tf.TopKV2 to XLA HLO ops

- Tightened tf.TopKV2 verification
- Defined xla_hlo.sort operation
- Added lowering from tf.TopKV2 to XLA HLO ops

PiperOrigin-RevId: 283623396
Change-Id: Ia705c72022452617ab1209532c6408c3cb399a9c
---
 .../mlir/tensorflow/ir/tf_generated_ops.td    |   2 +
 .../compiler/mlir/tensorflow/ir/tf_ops.cc     |  15 ++
 .../mlir/tensorflow/tests/tf-ops.mlir         |  16 +++
 tensorflow/compiler/mlir/xla/ir/hlo_ops.cc    |  64 +++++++++
 tensorflow/compiler/mlir/xla/ir/hlo_ops.td    |  20 +++
 .../compiler/mlir/xla/ir/hlo_ops_base.td      |  11 ++
 .../compiler/mlir/xla/mlir_hlo_to_hlo.cc      |  12 ++
 .../compiler/mlir/xla/tests/legalize-tf.mlir  |  39 +++++
 tensorflow/compiler/mlir/xla/tests/ops.mlir   |  95 ++++++++++++
 .../mlir/xla/tests/translate/export.mlir      |  18 +++
 .../mlir/xla/transforms/legalize_tf.cc        | 135 +++++++++++++++++-
 11 files changed, 423 insertions(+), 4 deletions(-)

diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 57b61461d02..cdc545d5681 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -5768,6 +5768,8 @@ If two elements are equal, the lower-index element appears first.
   );
 
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+
+  let verifier = [{ return Verify(*this); }];
 }
 
 def TF_TransposeOp : TF_Op<"Transpose", [NoSideEffect]> {
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
index 3b836a6188d..1bd9accbb78 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
@@ -1698,6 +1698,21 @@ static LogicalResult Verify(TensorListStackOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// TopKV2Op
+//===----------------------------------------------------------------------===//
+
+static LogicalResult Verify(TopKV2Op op) {
+  if (!HasRankAtLeast(op.input(), 1))
+    return op.emitOpError(
+        "requires input operand to have at least 1 dimension");
+
+  if (!IsOfRankOrUnranked(op.k(), 0))
+    return op.emitOpError("requires k operand to be 0D tensor");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
index 1914ca177cc..e064c1a53ef 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
@@ -1658,3 +1658,19 @@ func @testTernaryEinsum(%arg0: tensor<2x3xf32>){
   %0 = "tf.Einsum"(%arg0, %arg0, %arg0) {equation = "ab,cd,ef->"} : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> (tensor<*xf32>)
   return
 }
+
+// -----
+
+func @testTopKV2WrongInputRank(%input: tensor, %k: tensor) {
+  // expected-error @+1 {{op requires input operand to have at least 1 dimension}}
+  %0:2 = "tf.TopKV2"(%input, %k) : (tensor, tensor) -> (tensor<*xf32>, tensor<*xi32>)
+  return
+}
+
+// -----
+
+func @testTopKV2WrongKRank(%input: tensor<8xf32>, %k: tensor<5xi32>) {
+  // expected-error @+1 {{op requires k operand to be 0D tensor}}
+  %0:2 = "tf.TopKV2"(%input, %k) : (tensor<8xf32>, tensor<5xi32>) -> (tensor<*xf32>, tensor<*xi32>)
+  return
+}
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
index 8fa33d19363..b2f02bdf76f 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
@@ -841,6 +841,70 @@ Type SliceOp::InferOutputTypes(Builder* builder, Value* operand,
   return RankedTensorType::get(shape, ranked_ty.getElementType());
 }
 
+//===----------------------------------------------------------------------===//
+// SortOp
+//===----------------------------------------------------------------------===//
+
+void SortOp::build(Builder* builder, OperationState& state,
+                   ArrayRef operands, int64_t dimension,
+                   bool is_stable) {
+  state.addOperands(operands);
+  state.addAttribute("dimension", builder->getI64IntegerAttr(dimension));
+  state.addAttribute("is_stable", builder->getBoolAttr(dimension));
+
+  SmallVector element_types;
+  element_types.reserve(operands.size());
+  for (Value* operand : operands) element_types.push_back(operand->getType());
+  state.addTypes(builder->getTupleType(element_types));
+
+  state.addRegion();
+}
+
+static LogicalResult Verify(SortOp op) {
+  Operation::operand_range operands = op.operands();
+  if (operands.empty()) return op.emitOpError("requires at least one input");
+
+  // TODO(antiagainst): verify partionally dynamic shapes
+  if (llvm::all_of(operands, [](Value* operand) {
+        return operand->getType().cast().hasRank();
+      })) {
+    ArrayRef input_shape =
+        (*operands.begin())->getType().cast().getShape();
+
+    if (llvm::any_of(llvm::drop_begin(operands, 1), [&](Value* operand) {
+          return operand->getType().cast().getShape() !=
+                 input_shape;
+        }))
+      return op.emitOpError("requires all inputs to have the same dimensions");
+
+    if (op.dimension().getSExtValue() >= input_shape.size())
+      return op.emitOpError(
+          "dimension attribute value must be less than input rank");
+  }
+
+  Block& block = op.comparator().front();
+  size_t num_operands = op.getOperation()->getNumOperands();
+  if (block.getNumArguments() != 2 * num_operands)
+    return op.emitOpError("comparator block should have ")
+           << 2 * num_operands << " arguments";
+
+  for (auto indexed_operand : llvm::enumerate(operands)) {
+    int index = indexed_operand.index();
+    Type element_type =
+        indexed_operand.value()->getType().cast().getElementType();
+    Type tensor_type = RankedTensorType::get({}, element_type);
+    for (int i : {2 * index, 2 * index + 1}) {
+      Type arg_type = block.getArgument(i)->getType();
+      if (arg_type != tensor_type)
+        return op.emitOpError("comparator block argument #")
+               << i << " should be of type " << tensor_type << " but got "
+               << arg_type;
+    }
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
index e285b172806..c9b3e7985fc 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
@@ -868,6 +868,26 @@ def HLO_SelectAndScatterOp: HLO_Op<"select_and_scatter",
   let hasCustomHLOConverter = 1;
 }
 
+def HLO_SortOp : HLO_Op<"sort", [NoSideEffect]>, BASE_HLO_SortOp {
+  let arguments = (ins
+    Variadic:$operands,
+    DefaultValuedAttr:$dimension,
+    DefaultValuedAttr:$is_stable
+  );
+
+  let results = (outs HLO_TensorOrTuple);
+
+  let regions = (region SizedRegion<1>:$comparator);
+
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState &state, ArrayRef operands, "
+    "int64_t dimension, bool is_stable"
+  >];
+
+  // TODO(b/129422361): SortOp has special conversion logic to HLO.
+  let hasCustomHLOConverter = 1;
+}
+
 def HLO_ReverseOp: HLO_Op<"reverse",
       [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ReverseOp {
   let arguments = (ins
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
index a0c790616fa..a6d4210b60c 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
@@ -832,6 +832,17 @@ class BASE_HLO_SelectAndScatterOp {
   }];
 }
 
+class BASE_HLO_SortOp {
+  string summary = "Sort operator";
+
+  string description = [{
+    Sorts the given `operands` at the given `dimension` with the given
+    `comparator`.
+
+    See https://www.tensorflow.org/xla/operation_semantics#sort.
+  }];
+}
+
 class BASE_HLO_ReverseOp {
   string summary = "Reverse operator";
 
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
index 93716331d0d..e9bf3bac44b 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
@@ -624,6 +624,18 @@ LogicalResult ExportXlaOp(SliceOp op, OpLoweringContext ctx) {
   return failure();
 }
 
+LogicalResult ExportXlaOp(SortOp op, OpLoweringContext ctx) {
+  xla::XlaComputation comparator;
+  if (failed(ctx.converter->LowerRegionAsComputation(&op.comparator(),
+                                                     &comparator)))
+    return failure();
+
+  auto& value_map = *ctx.values;
+  value_map[op] = xla::Sort(GetTuple(op.operands(), ctx), comparator,
+                            op.dimension().getSExtValue(), op.is_stable());
+  return success();
+}
+
 LogicalResult ExportXlaOp(TupleOp op, OpLoweringContext ctx) {
   auto& value_map = *ctx.values;
   value_map[op] = xla::Tuple(ctx.builder, GetTuple(op.val(), ctx));
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 94a445fe8bd..8aa9b5ef101 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -1934,3 +1934,42 @@ func @split_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x2xf
   // CHECK: return %[[ONE]], %[[TWO]], %[[THREE]]
   return %0#0, %0#1, %0#2 : tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>
 }
+
+//===----------------------------------------------------------------------===//
+// tf.TopKV2 legalization
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: topk_v2_non_const_k
+func @topk_v2_non_const_k(%input: tensor<16xf32>, %k: tensor) -> (tensor, tensor) {
+  // CHECK: tf.TopKV2
+  %0:2 = "tf.TopKV2"(%input, %k): (tensor<16xf32>, tensor) -> (tensor, tensor)
+  return %0#0, %0#1: tensor, tensor
+}
+
+// CHECK-LABEL: topk_v2_unknown_input_last_dim
+func @topk_v2_unknown_input_last_dim(%input: tensor<16x?xf32>) -> (tensor<16x?xf32>, tensor<16x?xi32>) {
+  %k = "tf.Const"() {value = dense<8> : tensor} : () -> tensor
+  // CHECK: tf.TopKV2
+  %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x?xf32>, tensor) -> (tensor<16x?xf32>, tensor<16x?xi32>)
+  return %0#0, %0#1: tensor<16x?xf32>, tensor<16x?xi32>
+}
+
+// CHECK-LABEL: topk_v2
+// CHECK-SAME: %[[INPUT:.*]]: tensor<16x16xf32>
+func @topk_v2(%input: tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) {
+  %k = "tf.Const"() {value = dense<8> : tensor} : () -> tensor
+
+  // CHECK:      %[[IOTA:.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64}
+  // CHECK-NEXT: %[[SORT:.*]] = "xla_hlo.sort"(%[[INPUT]], %[[IOTA]]) ( {
+  // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor, %{{.*}}: tensor, %{{.*}}: tensor):
+  // CHECK-NEXT:   %[[CMP:.*]] = "xla_hlo.compare"(%[[LHS]], %[[RHS]]) {comparison_direction = "GT"}
+  // CHECK-NEXT:   "xla_hlo.return"(%[[CMP]])
+  // CHECK-NEXT: }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>>
+  // CHECK-NEXT: %[[TUPL0:.*]] = "xla_hlo.get_tuple_element"(%[[SORT]]) {index = 0 : i32}
+  // CHECK-NEXT: %[[TUPL1:.*]] = "xla_hlo.get_tuple_element"(%[[SORT]]) {index = 1 : i32}
+  // CHECK-NEXT: %[[VAL:.*]] = "xla_hlo.slice"(%[[TUPL0]]) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+  // CHECK-NEXT: %[[IDX:.*]] = "xla_hlo.slice"(%[[TUPL1]]) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+  // CHECK-NEXT: return %[[VAL]], %[[IDX]]
+  %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor) -> (tensor<16x8xf32>, tensor<16x8xi32>)
+  return %0#0, %0#1: tensor<16x8xf32>, tensor<16x8xi32>
+}
diff --git a/tensorflow/compiler/mlir/xla/tests/ops.mlir b/tensorflow/compiler/mlir/xla/tests/ops.mlir
index 225fc97bb22..4f142f294e4 100644
--- a/tensorflow/compiler/mlir/xla/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/ops.mlir
@@ -416,3 +416,98 @@ func @constants() -> () {
   %3 = "xla_hlo.constant"() {extra_attr = 3 : i32, value = dense<0> : tensor} : () -> (tensor<*xi32>)
   return
 }
+
+// -----
+
+func @sort(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
+  // CHECK: xla_hlo.sort
+  %0 = "xla_hlo.sort"(%input0, %input1) ( {
+  ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor):
+    %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor
+    "xla_hlo.return"(%7) : (tensor) -> ()
+  }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>>
+  return
+}
+
+// -----
+
+func @sort_no_operands() {
+  // expected-error @+1 {{op requires at least one input}}
+  %0 = "xla_hlo.sort"() ( {
+  ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor):
+    %7 = "xla_hlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor, tensor) -> tensor
+    "xla_hlo.return"(%7) : (tensor) -> ()
+  }) {dimension = 1 : i64, is_stable = true} : () -> tuple<>
+  return
+}
+
+// -----
+
+func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) {
+  %0 = "xla_hlo.sort"(%input0, %input1) ( {
+  ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor):
+    %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor
+    "xla_hlo.return"(%7) : (tensor) -> ()
+  }) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>>
+  return
+}
+
+// -----
+
+func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) {
+  // expected-error @+1 {{comparator block argument #0 should be of type 'tensor' but got 'tensor'}}
+  %0 = "xla_hlo.sort"(%input0, %input1) ( {
+  ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor):
+    %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor
+    "xla_hlo.return"(%7) : (tensor) -> ()
+  }) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>>
+  return
+}
+
+// -----
+
+func @sort_different_dims(%input0: tensor<16x8xf32>, %input1: tensor<16x16xi32>) {
+  // expected-error @+1 {{op requires all inputs to have the same dimensions}}
+  %0 = "xla_hlo.sort"(%input0, %input1) ( {
+  ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor):
+    %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor
+    "xla_hlo.return"(%7) : (tensor) -> ()
+  }) {dimension = 1 : i64, is_stable = true} : (tensor<16x8xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>>
+  return
+}
+
+// -----
+
+func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
+  // expected-error @+1 {{op dimension attribute value must be less than input rank}}
+  %0 = "xla_hlo.sort"(%input0, %input1) ( {
+  ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor):
+    %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor
+    "xla_hlo.return"(%7) : (tensor) -> ()
+  }) {dimension = 10 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>>
+  return
+}
+
+// -----
+
+func @sort_wrong_block_arg_count(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
+  // expected-error @+1 {{op comparator block should have 4 arguments}}
+  %0 = "xla_hlo.sort"(%input0, %input1) ( {
+  ^bb0(%arg0: tensor, %arg1: tensor):
+    %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor
+    "xla_hlo.return"(%7) : (tensor) -> ()
+  }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>>
+  return
+}
+
+// -----
+
+func @sort_wrong_block_arg_type(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
+  // expected-error @+1 {{op comparator block argument #3 should be of type 'tensor' but got 'tensor'}}
+  %0 = "xla_hlo.sort"(%input0, %input1) ( {
+  ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor):
+    %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor
+    "xla_hlo.return"(%7) : (tensor) -> ()
+  }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>>
+  return
+}
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
index 70b48fa43c9..ffcc1cc9df3 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
@@ -620,3 +620,21 @@ func @main(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
   // CHECK:  ROOT [[VAL_3:%.*]] = pred[4] xor(pred[4] [[VAL_1]], pred[4] [[VAL_2]])
   return %0 : tensor<4xi1>
 }
+
+// -----
+
+// CHECK-LABEL:  HloModule
+func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
+  %0 = "xla_hlo.sort"(%input0, %input1) ( {
+  ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor):
+    %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor
+    "xla_hlo.return"(%7) : (tensor) -> ()
+  }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>>
+  return
+}
+
+// CHECK: %[[SORT_CMP:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[], {{.*}}: s32[], {{.*}}: s32[]) -> pred[] {
+// CHECK:   ROOT %compare.8 = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GT
+
+// CHECK: ENTRY %{{.*}} ([[MAIN_ARG0:.*]]: f32[16,16], [[MAIN_ARG1:.*]]: s32[16,16]) -> (f32[16,16], s32[16,16]) {
+// CHECK:   ROOT %{{.*}} = (f32[16,16], s32[16,16]) sort(f32[16,16] %[[MAIN_ARG0]], s32[16,16] %[[MAIN_ARG1]]), dimensions={1}, is_stable=true, to_apply=%[[SORT_CMP]]
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index d7f3bf243e5..f0ba67e2fd5 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -438,6 +438,38 @@ static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes(
   return GetI64ElementsAttr(normalized_sizes, builder);
 }
 
+//===----------------------------------------------------------------------===//
+// Sort op utilities.
+//===----------------------------------------------------------------------===//
+
+// Builds the region `body` for xla_hlo.sort's comparator: for each type in
+// `element_types`, create two block arguments, one for lhs and one for rhs, and
+// generates xla_hlo.compare op to compare them with the given `direction`.
+//
+// Note that this right now only does comparsion on the first pair of block
+// arguments.
+static void BuildSortComparisonBody(llvm::ArrayRef element_types,
+                                    StringRef direction, Region *body,
+                                    OpBuilder *builder) {
+  OpBuilder::InsertionGuard insertion_point_gurad(*builder);
+
+  Block *block = builder->createBlock(body);
+  // Add two arguments for each element type.
+  for (Type element_type : element_types) {
+    TensorType tensor_type = RankedTensorType::get({}, element_type);
+    block->addArguments({tensor_type, tensor_type});
+  }
+
+  Location loc = body->getLoc();
+  StringAttr compare_direction =
+      StringAttr::get(direction, builder->getContext());
+  Value *compare = builder->create(
+      loc, block->getArgument(0), block->getArgument(1),
+      /*broadcast_dimensions=*/nullptr, compare_direction);
+
+  builder->create(loc, compare);
+}
+
 //===----------------------------------------------------------------------===//
 // Op converters.
 //===----------------------------------------------------------------------===//
@@ -1873,6 +1905,101 @@ class ConvertOneHotOp : public OpRewritePattern {
   }
 };
 
+// Converts tf.TopKV2 to XLA HLO iota, sort, and slice ops when k is a constant.
+//
+// tf.TopKV2 sorts along last dimension of the input tensor and then returns
+// the top K components' values and indices. This is translated into a few
+// ops in XLA HLO: first generating an integer sequence for the indices,
+// then sort both the original input tensor and the indices togheter, and
+// at last slice out the top K components.
+//
+// For example, for the following IR:
+//
+// %k = "tf.Const"() {value = dense<8> : tensor} : () -> tensor
+// %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor) ->
+//                                 (tensor<16x8xf32>, tensor<16x8xi32>)
+//
+// We will get:
+//
+// %1 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<16x16xi32>
+// %2 = "xla_hlo.sort"(%input, %1) ( {
+// ^bb0(%arg1: tensor, %arg2: tensor,
+//      %arg3: tensor, %arg4: tensor):
+//   %7 = "xla_hlo.compare"(%arg1, %arg2) {comparison_direction = "GT"}: ...
+//   "xla_hlo.return"(%7) : (tensor) -> ()
+// }) {dimension = 1 : i64, is_stable = true} : ...
+// %3 = "xla_hlo.get_tuple_element"(%2) {index = 0 : i32} : ...
+// %4 = "xla_hlo.get_tuple_element"(%2) {index = 1 : i32} : ...
+// %5 = "xla_hlo.slice"(%3) {limit_indices = dense<[16, 8]> : tensor<2xi64>,
+//                           start_indices dense<0> : tensor<2xi64>,
+//                           strides = dense<1> : tensor<2xi64>} :
+//                              (tensor<16x16xf32>) -> tensor<16x8xf32>
+// %6 = "xla_hlo.slice"(%4) ...
+class ConvertTopKV2Op : public OpRewritePattern {
+ public:
+  using OpRewritePattern::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(TF::TopKV2Op op,
+                                     PatternRewriter &rewriter) const override {
+    // We can only match when the `k` operand is a constant scalar.
+    DenseIntElementsAttr k_attr;
+    if (!matchPattern(op.k(), m_Constant(&k_attr))) return matchFailure();
+
+    // The last dimension of the input tensor's shape should be known so we can
+    // have clamped end_indices for slices.
+    TensorType input_type = op.input()->getType().cast();
+    if (!input_type.hasRank()) return matchFailure();
+    int64_t input_rank = input_type.getRank();
+    int64_t last_dim_index = input_rank - 1;
+    int64_t last_dim_size = input_type.getDimSize(last_dim_index);
+    if (last_dim_size == ShapedType::kDynamicSize) return matchFailure();
+
+    // Create an Itoa op for indices.
+    auto i32_type = rewriter.getIntegerType(32);
+    Type iota_type = RankedTensorType::get(input_type.getShape(), i32_type);
+    Value *iota_op = rewriter.create(
+        op.getLoc(), iota_type, rewriter.getI64IntegerAttr(last_dim_index));
+
+    // Create the sort op. It takes two inputs, one for the original input, the
+    // other for the indices.
+    auto sort_op = rewriter.create(
+        op.getLoc(), llvm::ArrayRef{op.input(), iota_op},
+        last_dim_index, /*is_stable=*/true);
+    BuildSortComparisonBody({input_type.getElementType(), i32_type},
+                            /*direction=*/"GT", &sort_op.comparator(),
+                            &rewriter);
+
+    // Get the sorted input and index tuple element.
+    auto tuple_first_element =
+        rewriter.create(op.getLoc(), sort_op, 0);
+    auto tuple_second_element =
+        rewriter.create(op.getLoc(), sort_op, 1);
+
+    SmallVector begin_indices(input_rank, 0);
+    auto end_indices = llvm::to_vector<4>(input_type.getShape());
+    end_indices.back() =
+        std::min((*k_attr.begin()).getSExtValue(), last_dim_size);
+    SmallVector strides(input_rank, 1);
+
+    // Get the slice for the top K elements.
+
+    Value *values = rewriter.create(
+        op.getLoc(), tuple_first_element,
+        GetI64ElementsAttr(begin_indices, &rewriter),
+        GetI64ElementsAttr(end_indices, &rewriter),
+        GetI64ElementsAttr(strides, &rewriter));
+
+    Value *indices = rewriter.create(
+        op.getLoc(), tuple_second_element,
+        GetI64ElementsAttr(begin_indices, &rewriter),
+        GetI64ElementsAttr(end_indices, &rewriter),
+        GetI64ElementsAttr(strides, &rewriter));
+
+    rewriter.replaceOp(op, {values, indices});
+    return matchSuccess();
+  }
+};
+
 #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"
 
 LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
@@ -1892,10 +2019,10 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
               ConvertSigmoidOp, ConvertSizeOp, ConvertMaxPoolOp, ConvertRangeOp,
               ConvertSigmoidOp, ConvertSoftmaxOp,
               ConvertSoftmaxOp, ConvertSplitOp,
-              ConvertStridedSliceOp, ConvertMeanOp, ConvertSumOp, ConvertMaxOp,
-              ConvertTileOp, ConvertMaxPoolGradOp, ConvertOneHotOp,
-              ConvertConv2DBackpropInputOp, ConvertConv2DBackpropFilterOp>(
-          op->getContext());
+              ConvertStridedSliceOp, ConvertTopKV2Op, ConvertMeanOp,
+              ConvertSumOp, ConvertMaxOp, ConvertTileOp, ConvertMaxPoolGradOp,
+              ConvertOneHotOp, ConvertConv2DBackpropInputOp,
+              ConvertConv2DBackpropFilterOp>(op->getContext());
 
   ConversionTarget target(*context);
   target.addLegalDialect();

From 8cf05f47a5c2c8bf40e6e2ee75267fde7f83d107 Mon Sep 17 00:00:00 2001
From: Shanqing Cai 
Date: Tue, 3 Dec 2019 14:31:17 -0800
Subject: [PATCH 241/279] [tfdbg] Shard distributed_callbacks_test to reduce
 likelihood of timeouts

PiperOrigin-RevId: 283623507
Change-Id: I021d13454667d2e92495a5361195b899cdc6eea9
---
 tensorflow/python/debug/BUILD | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 1c30328e7dd..97fe48ee165 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -725,6 +725,7 @@ cuda_py_test(
         "//tensorflow/python/keras",
     ],
     python_version = "PY3",
+    shard_count = 8,
     tags = [
         "guitar",
         "multi_and_single_gpu",

From 0365083580cdfc786a9c5a29eb33fc1144f74eb2 Mon Sep 17 00:00:00 2001
From: HyoukJoong Lee 
Date: Tue, 3 Dec 2019 14:37:30 -0800
Subject: [PATCH 242/279] Infer sharding from neighbors only for maximal
 sharding

PiperOrigin-RevId: 283624742
Change-Id: I290f54778c2c5406d4161045f6e8d3df39ce96b1
---
 tensorflow/compiler/tf2xla/tf2xla_util.cc | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc
index e82546def46..8dc44eac51a 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc
@@ -503,8 +503,7 @@ Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) {
         ParseShardingFromDevice(
             *possible_match,
             /*num_cores_per_replica=*/std::numeric_limits::max()));
-    if (sharding.has_value()) {
-      TF_RET_CHECK(sharding.value().type() == xla::OpSharding::MAXIMAL);
+    if (sharding && sharding->type() == xla::OpSharding::MAXIMAL) {
       const int core_annotation = sharding.value().tile_assignment_devices(0);
       if (core == -1 || core > core_annotation) {
         core = core_annotation;

From d68fa265867de3bf8e1e79c0504849f314ca67c7 Mon Sep 17 00:00:00 2001
From: Haoyu Zhang 
Date: Tue, 3 Dec 2019 14:37:49 -0800
Subject: [PATCH 243/279] Add tf.lite tests for tf.signal.frame.

PiperOrigin-RevId: 283624804
Change-Id: I6fd2f5234fb474de3cc5490e926938f76ffe3b5b
---
 .../kernel_tests/signal/shape_ops_test.py     | 46 +------------------
 1 file changed, 1 insertion(+), 45 deletions(-)

diff --git a/tensorflow/python/kernel_tests/signal/shape_ops_test.py b/tensorflow/python/kernel_tests/signal/shape_ops_test.py
index e9056accc71..6d9c77a0136 100644
--- a/tensorflow/python/kernel_tests/signal/shape_ops_test.py
+++ b/tensorflow/python/kernel_tests/signal/shape_ops_test.py
@@ -18,9 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import itertools
-
-from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python.eager import context
@@ -28,7 +25,6 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.kernel_tests.signal import test_util
 from tensorflow.python.ops import array_ops
@@ -38,7 +34,7 @@ from tensorflow.python.platform import test
 
 
 @tf_test_util.run_all_in_graph_and_eager_modes
-class FrameTest(test.TestCase, parameterized.TestCase):
+class FrameTest(test.TestCase):
 
   def test_mapping_of_indices_without_padding(self):
     tensor = constant_op.constant(np.arange(9152), dtypes.int32)
@@ -356,46 +352,6 @@ class FrameTest(test.TestCase, parameterized.TestCase):
         rewritten_graph = test_util.grappler_optimize(g, [frames])
         self.assertEqual(1, len(rewritten_graph.node))
 
-  @parameterized.parameters(
-      itertools.product(
-          # length % step == 0
-          ((32, 16),
-           # gcd(length, step) == 1
-           (32, 15),
-           # gcd(length, step) == 5
-           (25, 15),
-           # length == step
-           (32, 32)),
-          (False, True),  # pad_end
-          (False, True),  # use_mlir
-          (False, True)))  # known_batch
-  def test_tflite_convert(self, length_step, pad_end, use_mlir, known_batch):
-    """Check for tf.lite compatibility in a variety of settings."""
-    def fn(signal):
-      return shape_ops.frame(
-          signal, length_step[0], length_step[1], pad_end=pad_end)
-
-    # TODO(b/144998258): unknown batch does not currently work with padding.
-    if not known_batch and pad_end:
-      return
-
-    signal_length, dtype = 8001, dtypes.float32
-    # If batch size is unknown, tf.lite assumes it's 1. Test batch_size > 1
-    # only when batch size is known.
-    batch_size = 2 if known_batch else 1
-    static_batch_size = batch_size if known_batch else None
-    tflite_model = test_util.tflite_convert(
-        fn, [tensor_spec.TensorSpec(
-            shape=[static_batch_size, signal_length], dtype=dtype)],
-        use_mlir)
-    signal = np.random.normal(size=(batch_size, signal_length)).astype(
-        dtype.as_numpy_dtype)
-    actual_output, = test_util.evaluate_tflite_model(
-        tflite_model, [signal])
-
-    expected_output = self.evaluate(fn(signal))
-    self.assertAllClose(actual_output, expected_output)
-
 
 if __name__ == "__main__":
   test.main()

From 365ca18152ca5cd15dd3b96914fb1221dc8c1b83 Mon Sep 17 00:00:00 2001
From: Alexandre Passos 
Date: Tue, 3 Dec 2019 14:44:49 -0800
Subject: [PATCH 244/279] Fix floating point golden test for sigmoid.

PiperOrigin-RevId: 283626203
Change-Id: I64ed3747b40e18c09520556e642b5826367cbd4e
---
 tensorflow/python/keras/activations.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/tensorflow/python/keras/activations.py b/tensorflow/python/keras/activations.py
index 55440dd4017..17af5d36b41 100644
--- a/tensorflow/python/keras/activations.py
+++ b/tensorflow/python/keras/activations.py
@@ -260,9 +260,8 @@ def sigmoid(x):
 
   >>> a = tf.constant([-20, -1.0, 0.0, 1.0, 20], dtype = tf.float32)
   >>> b = tf.keras.activations.sigmoid(a)
-  >>> b.numpy()
-  array([0.        , 0.26894143, 0.5       , 0.7310586 , 1.        ],
-         dtype=float32)
+  >>> b.numpy() > 0.0
+  array([False,  True,  True,  True,  True])
 
   Arguments:
       x: Input tensor.

From f81e64ca009675c9bc07ef5b230fa25a59e1d806 Mon Sep 17 00:00:00 2001
From: Dan Moldovan 
Date: Tue, 3 Dec 2019 14:53:28 -0800
Subject: [PATCH 245/279] Whitelist known pure-Python library - PIL.

PiperOrigin-RevId: 283627816
Change-Id: I40f7aca7530008c6cbc87b5865d7cb51aba927c7
---
 tensorflow/python/autograph/core/config.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tensorflow/python/autograph/core/config.py b/tensorflow/python/autograph/core/config.py
index 41d05ce6502..b336ea771d3 100644
--- a/tensorflow/python/autograph/core/config.py
+++ b/tensorflow/python/autograph/core/config.py
@@ -49,6 +49,7 @@ CONVERSION_RULES = (
     # Known libraries
     DoNotConvert('numpy'),
     DoNotConvert('tensorflow'),
+    DoNotConvert('PIL'),
 
     # TODO(b/133417201): Remove.
     DoNotConvert('tensorflow_probability'),

From 24215adab4506add5a48ac7164944d5b6fcba9e1 Mon Sep 17 00:00:00 2001
From: Juhyun Lee 
Date: Tue, 3 Dec 2019 14:54:58 -0800
Subject: [PATCH 246/279] Fix a typo: TfLiteConvParams -> TfLitePoolParams.

PiperOrigin-RevId: 283628103
Change-Id: Iffb3e324540e1d0cc1720cc72cfce805f306e13e
---
 tensorflow/lite/experimental/micro/kernels/pooling_test.cc | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tensorflow/lite/experimental/micro/kernels/pooling_test.cc b/tensorflow/lite/experimental/micro/kernels/pooling_test.cc
index d2f8f41edcd..03909b994f8 100644
--- a/tensorflow/lite/experimental/micro/kernels/pooling_test.cc
+++ b/tensorflow/lite/experimental/micro/kernels/pooling_test.cc
@@ -54,7 +54,7 @@ void TestAveragePoolingFloat(std::initializer_list input_dims_data,
       resolver.FindOp(tflite::BuiltinOperator_AVERAGE_POOL_2D, 1);
   TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
 
-  TfLiteConvParams builtin_data = {padding,      stride_width,  stride_height,
+  TfLitePoolParams builtin_data = {padding,      stride_width,  stride_height,
                                    filter_width, filter_height, activation};
   const char* init_data = reinterpret_cast(&builtin_data);
   size_t init_data_size = 0;
@@ -125,7 +125,7 @@ void TestAveragePoolingUint8(
       resolver.FindOp(tflite::BuiltinOperator_AVERAGE_POOL_2D, 1);
   TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
 
-  TfLiteConvParams builtin_data = {padding,      stride_width,  stride_height,
+  TfLitePoolParams builtin_data = {padding,      stride_width,  stride_height,
                                    filter_width, filter_height, activation};
   const char* init_data = reinterpret_cast(&builtin_data);
   size_t init_data_size = 0;
@@ -198,7 +198,7 @@ void TestAveragePoolingInt8(std::initializer_list input_dims_data,
       resolver.FindOp(tflite::BuiltinOperator_AVERAGE_POOL_2D, 1);
   TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
 
-  TfLiteConvParams builtin_data = {padding,      stride_width,  stride_height,
+  TfLitePoolParams builtin_data = {padding,      stride_width,  stride_height,
                                    filter_width, filter_height, activation};
   const char* init_data = reinterpret_cast(&builtin_data);
   size_t init_data_size = 0;

From 60a7aa9e9c6249d2e82ed79d3a7de2715e420e5c Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Tue, 3 Dec 2019 15:10:17 -0800
Subject: [PATCH 247/279] Support ROI ops in the tflite->flow converter.

PiperOrigin-RevId: 283631528
Change-Id: I0614721e9f4d508edb4c02dce6ca48f4b63e983d
---
 tensorflow/lite/delegates/gpu/common/BUILD    |  14 +++
 .../delegates/gpu/common/custom_parsers.cc    |  36 ++++++
 .../delegates/gpu/common/custom_parsers.h     |  37 ++++++
 .../delegates/gpu/common/model_builder.cc     | 117 ++++++++++++++++++
 4 files changed, 204 insertions(+)
 create mode 100644 tensorflow/lite/delegates/gpu/common/custom_parsers.cc
 create mode 100644 tensorflow/lite/delegates/gpu/common/custom_parsers.h

diff --git a/tensorflow/lite/delegates/gpu/common/BUILD b/tensorflow/lite/delegates/gpu/common/BUILD
index 20ff67677a9..4da852b0565 100644
--- a/tensorflow/lite/delegates/gpu/common/BUILD
+++ b/tensorflow/lite/delegates/gpu/common/BUILD
@@ -19,6 +19,19 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "custom_parsers",
+    srcs = ["custom_parsers.cc"],
+    hdrs = ["custom_parsers.h"],
+    deps = [
+        "//tensorflow/lite/delegates/gpu/common:shape",
+        "//tensorflow/lite/delegates/gpu/common:status",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:any",
+        "@flatbuffers",
+    ],
+)
+
 cc_library(
     name = "access_type",
     hdrs = ["access_type.h"],
@@ -96,6 +109,7 @@ cc_library(
     srcs = ["model_builder.cc"],
     hdrs = ["model_builder.h"],
     deps = [
+        ":custom_parsers",
         ":data_type",
         ":model",
         ":operations",
diff --git a/tensorflow/lite/delegates/gpu/common/custom_parsers.cc b/tensorflow/lite/delegates/gpu/common/custom_parsers.cc
new file mode 100644
index 00000000000..d46a9247c81
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/common/custom_parsers.cc
@@ -0,0 +1,36 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/delegates/gpu/common/custom_parsers.h"
+
+#include 
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/any.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+
+namespace tflite {
+namespace gpu {
+
+Status ParseCustomAttributes(absl::string_view op_name, const void* data,
+                             uint32_t data_size, absl::any* attr,
+                             BHWC* output_shape) {
+  return UnimplementedError(absl::StrCat(
+      "Attributes parsing is not enabled for ", op_name, " operation"));
+}
+
+}  // namespace gpu
+}  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/common/custom_parsers.h b/tensorflow/lite/delegates/gpu/common/custom_parsers.h
new file mode 100644
index 00000000000..e9a191d46cb
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/common/custom_parsers.h
@@ -0,0 +1,37 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CUSTOM_PARSERS_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CUSTOM_PARSERS_H_
+
+#include 
+
+#include "absl/strings/string_view.h"
+#include "absl/types/any.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+
+namespace tflite {
+namespace gpu {
+
+// Matches the custom operation by the string name and parses attributes stored
+// as flexbuffers.
+Status ParseCustomAttributes(absl::string_view op_name, const void* data,
+                             uint32_t data_size, absl::any* attr,
+                             BHWC* output_shape);
+
+}  // namespace gpu
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CUSTOM_PARSERS_H_
diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc
index 4aec15f0b67..8e33c4eeb75 100644
--- a/tensorflow/lite/delegates/gpu/common/model_builder.cc
+++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc
@@ -36,6 +36,7 @@ limitations under the License.
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/context.h"
+#include "tensorflow/lite/delegates/gpu/common/custom_parsers.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
@@ -2227,6 +2228,110 @@ class SpaceToBatchOperationParser : public TFLiteOperationParser {
   }
 };
 
+class RoIToTransformMatrixOperationParser : public TFLiteOperationParser {
+ public:
+  Status IsSupported(const TfLiteContext* context,
+                     const TfLiteNode* tflite_node,
+                     const TfLiteRegistration* registration) final {
+    RETURN_IF_ERROR(
+        CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
+    return OkStatus();
+  }
+
+  Status Parse(const TfLiteNode* tflite_node,
+               const TfLiteRegistration* registration, GraphFloat32* graph,
+               ObjectReader* reader) final {
+    Node* node = graph->NewNode();
+    RETURN_IF_ERROR(reader->AddInput(node, 0));  // bbox
+    RETURN_IF_ERROR(reader->AddOutputs(node));
+
+    std::string op_name = "roi_to_transform_matrix";
+    node->operation.type = op_name;
+    BHWC output_shape;
+    RETURN_IF_ERROR(
+        ParseCustomAttributes(op_name, tflite_node->custom_initial_data,
+                              tflite_node->custom_initial_data_size,
+                              &(node->operation.attributes), &output_shape));
+
+    auto output_value = graph->FindOutputs(node->id)[0];
+    output_value->tensor.shape = output_shape;
+    return OkStatus();
+  }
+
+ private:
+};
+
+class TransformTensorOperationParser : public TFLiteOperationParser {
+ public:
+  Status IsSupported(const TfLiteContext* context,
+                     const TfLiteNode* tflite_node,
+                     const TfLiteRegistration* registration) final {
+    RETURN_IF_ERROR(
+        CheckInputsOutputs(context, tflite_node, /*inputs=*/2, /*outputs=*/1));
+    return OkStatus();
+  }
+
+  Status Parse(const TfLiteNode* tflite_node,
+               const TfLiteRegistration* registration, GraphFloat32* graph,
+               ObjectReader* reader) final {
+    Node* node = graph->NewNode();
+    RETURN_IF_ERROR(reader->AddInput(node, 0));  // data
+    RETURN_IF_ERROR(reader->AddInput(node, 1));  // bbox
+    RETURN_IF_ERROR(reader->AddOutputs(node));
+
+    std::string op_name = "transform_tensor";
+    node->operation.type = op_name;
+    BHWC output_shape;
+    RETURN_IF_ERROR(
+        ParseCustomAttributes(op_name, tflite_node->custom_initial_data,
+                              tflite_node->custom_initial_data_size,
+                              &(node->operation.attributes), &output_shape));
+
+    auto output_value = graph->FindOutputs(node->id)[0];
+
+    output_value->tensor.shape =
+        BHWC(1, output_shape.h, output_shape.w,
+             graph->FindInputs(node->id)[0]->tensor.shape.c);
+    return OkStatus();
+  }
+
+ private:
+};
+
+class TransformLandmarksOperationParser : public TFLiteOperationParser {
+ public:
+  Status IsSupported(const TfLiteContext* context,
+                     const TfLiteNode* tflite_node,
+                     const TfLiteRegistration* registration) final {
+    RETURN_IF_ERROR(
+        CheckInputsOutputs(context, tflite_node, /*inputs=*/2, /*outputs=*/1));
+    return OkStatus();
+  }
+
+  Status Parse(const TfLiteNode* tflite_node,
+               const TfLiteRegistration* registration, GraphFloat32* graph,
+               ObjectReader* reader) final {
+    Node* node = graph->NewNode();
+    RETURN_IF_ERROR(reader->AddInput(node, 0));  // data
+    RETURN_IF_ERROR(reader->AddInput(node, 1));  // bbox
+    RETURN_IF_ERROR(reader->AddOutputs(node));
+    std::string op_name = "transform_landmarks";
+    node->operation.type = op_name;
+    BHWC output_shape;
+    RETURN_IF_ERROR(
+        ParseCustomAttributes(op_name, tflite_node->custom_initial_data,
+                              tflite_node->custom_initial_data_size,
+                              &(node->operation.attributes), &output_shape));
+
+    auto output_value = graph->FindOutputs(node->id)[0];
+
+    output_value->tensor.shape = graph->FindInputs(node->id)[0]->tensor.shape;
+    return OkStatus();
+  }
+
+ private:
+};
+
 class UnsupportedOperationParser : public TFLiteOperationParser {
  public:
   Status IsSupported(const TfLiteContext* context,
@@ -2332,6 +2437,18 @@ std::unique_ptr NewOperationParser(
       if (custom_name == "MaxUnpooling2D") {
         return absl::make_unique();
       }
+      if (custom_name == "RoIToTransformMatrix") {
+        return absl::make_unique();
+      }
+
+      if (custom_name == "TransformTensor") {
+        return absl::make_unique();
+      }
+
+      if (custom_name == "TransformLandmarks") {
+        return absl::make_unique();
+      }
+
       break;
   }
   return absl::make_unique();

From ad42cf3ed6a544ca6f14f1e7734db9387c122eb2 Mon Sep 17 00:00:00 2001
From: Katherine Wu 
Date: Tue, 3 Dec 2019 15:19:33 -0800
Subject: [PATCH 248/279] When exporting SavedModel, force functional and
 sequential models to save config even if error occurs when serializing
 layers.

This ensures that the network structure is saved even if a custom layer doesn't define its config.

PiperOrigin-RevId: 283633369
Change-Id: I59c4e7dcb9acca837bc6534af046c7f21663ff24
---
 tensorflow/python/keras/BUILD                 | 12 +++++
 tensorflow/python/keras/engine/sequential.py  | 18 ++++---
 .../saving/saved_model/layer_serialization.py | 23 +++++----
 .../saved_model/model_serialization_test.py   | 48 +++++++++++++++++++
 .../python/keras/utils/generic_utils.py       | 33 ++++++++++---
 5 files changed, 109 insertions(+), 25 deletions(-)
 create mode 100644 tensorflow/python/keras/saving/saved_model/model_serialization_test.py

diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index c3abf49ae59..88b6165c2a1 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -1856,6 +1856,18 @@ tf_py_test(
     ],
 )
 
+tf_py_test(
+    name = "model_serialization_test",
+    size = "medium",
+    srcs = ["saving/saved_model/model_serialization_test.py"],
+    additional_deps = [
+        ":keras",
+        "@absl_py//absl/testing:parameterized",
+        "//tensorflow/python/distribute:mirrored_strategy",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
 tf_py_test(
     name = "saving_utils_test",
     size = "medium",
diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py
index 369cd31d656..522aed6aaa4 100644
--- a/tensorflow/python/keras/engine/sequential.py
+++ b/tensorflow/python/keras/engine/sequential.py
@@ -345,13 +345,17 @@ class Sequential(training.Model):
     layer_configs = []
     for layer in self.layers:
       layer_configs.append(generic_utils.serialize_keras_object(layer))
-    # When constructed using an `InputLayer` the first non-input layer may not
-    # have the shape information to reconstruct `Sequential` as a graph network.
-    if (self._is_graph_network and layer_configs and
-        'batch_input_shape' not in layer_configs[0]['config'] and
-        isinstance(self._layers[0], input_layer.InputLayer)):
-      batch_input_shape = self._layers[0]._batch_input_shape
-      layer_configs[0]['config']['batch_input_shape'] = batch_input_shape
+
+    if layer_configs and layer_configs[0]['config'] is not None:
+      # layer_configs[0]['config'] may be None only when saving SavedModel.
+
+      # Check to see whether the first non-input layer has the shape information
+      # to reconstruct `Sequential` as a graph network. If not, add it.
+      if (self._is_graph_network and
+          'batch_input_shape' not in layer_configs[0]['config'] and
+          isinstance(self._layers[0], input_layer.InputLayer)):
+        batch_input_shape = self._layers[0]._batch_input_shape
+        layer_configs[0]['config']['batch_input_shape'] = batch_input_shape
 
     config = {
         'name': self.name,
diff --git a/tensorflow/python/keras/saving/saved_model/layer_serialization.py b/tensorflow/python/keras/saving/saved_model/layer_serialization.py
index 054a01e1db0..ab1edaab585 100644
--- a/tensorflow/python/keras/saving/saved_model/layer_serialization.py
+++ b/tensorflow/python/keras/saving/saved_model/layer_serialization.py
@@ -23,7 +23,7 @@ from tensorflow.python.keras.saving.saved_model import base_serialization
 from tensorflow.python.keras.saving.saved_model import constants
 from tensorflow.python.keras.saving.saved_model import save_impl
 from tensorflow.python.keras.saving.saved_model import serialized_attributes
-from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
+from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.util import nest
 
 
@@ -51,23 +51,22 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):
         expects_training_arg=self.obj._expects_training_arg,  # pylint: disable=protected-access
         dtype=policy.serialize(self.obj._dtype_policy),  # pylint: disable=protected-access
         batch_input_shape=getattr(self.obj, '_batch_input_shape', None))
-    try:
-      # Store the config dictionary, which is only used by the revived object
-      # to return the original config when revived_obj.get_config() is called.
-      # It is not important for recreating the revived object.
-      metadata['config'] = self.obj.get_config()
-    except NotImplementedError:
-      # in the case of a subclassed model, the get_config() method will throw
-      # a NotImplementedError.
-      pass
+
+    with generic_utils.skip_failed_serialization():
+      # Store the config dictionary, which may be used when reviving the object.
+      # When loading, the program will attempt to revive the object from config,
+      # and if that fails, the object will be revived from the SavedModel.
+      config = generic_utils.serialize_keras_object(self.obj)['config']
+      if config is not None:
+        metadata['config'] = config
     if self.obj.input_spec is not None:
       # Layer's input_spec has already been type-checked in the property setter.
       metadata['input_spec'] = nest.map_structure(
-          lambda x: None if x is None else serialize_keras_object(x),
+          lambda x: generic_utils.serialize_keras_object(x) if x else None,
           self.obj.input_spec)
     if (self.obj.activity_regularizer is not None and
         hasattr(self.obj.activity_regularizer, 'get_config')):
-      metadata['activity_regularizer'] = serialize_keras_object(
+      metadata['activity_regularizer'] = generic_utils.serialize_keras_object(
           self.obj.activity_regularizer)
     return metadata
 
diff --git a/tensorflow/python/keras/saving/saved_model/model_serialization_test.py b/tensorflow/python/keras/saving/saved_model/model_serialization_test.py
new file mode 100644
index 00000000000..125ab2fd958
--- /dev/null
+++ b/tensorflow/python/keras/saving/saved_model/model_serialization_test.py
@@ -0,0 +1,48 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=protected-access
+"""Unit tests for serializing Keras models."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python import keras
+from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras import testing_utils
+from tensorflow.python.platform import test
+
+
+class CustomLayer(keras.layers.Layer):
+
+  def __init__(self, unused_a):
+    super(CustomLayer, self).__init__()
+
+
+class ModelSerializationTest(keras_parameterized.TestCase):
+
+  @keras_parameterized.run_with_all_model_types(exclude_models=['subclass'])
+  def test_model_config_always_saved(self):
+    layer = CustomLayer(None)
+    with self.assertRaisesRegexp(NotImplementedError,
+                                 'must override `get_config`.'):
+      layer.get_config()
+    model = testing_utils.get_model_from_layers([layer], input_shape=(3,))
+    properties = model._trackable_saved_model_saver.python_properties
+    self.assertIsNotNone(properties['config'])
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/keras/utils/generic_utils.py b/tensorflow/python/keras/utils/generic_utils.py
index 8ff27a38d77..8b899dc0c74 100644
--- a/tensorflow/python/keras/utils/generic_utils.py
+++ b/tensorflow/python/keras/utils/generic_utils.py
@@ -30,6 +30,7 @@ import numpy as np
 import six
 
 from tensorflow.python.util import nest
+from tensorflow.python.util import tf_contextlib
 from tensorflow.python.util import tf_decorator
 from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.tf_export import keras_export
@@ -37,6 +38,11 @@ from tensorflow.python.util.tf_export import keras_export
 _GLOBAL_CUSTOM_OBJECTS = {}
 _GLOBAL_CUSTOM_NAMES = {}
 
+# Flag that determines whether to skip the NotImplementedError when calling
+# get_config in custom models and layers. This is only enabled when saving to
+# SavedModel, when the config isn't required.
+_SKIP_FAILED_SERIALIZATION = False
+
 
 @keras_export('keras.utils.CustomObjectScope')
 class CustomObjectScope(object):
@@ -187,6 +193,17 @@ def _get_name_or_custom_name(obj):
     return obj.__name__
 
 
+@tf_contextlib.contextmanager
+def skip_failed_serialization():
+  global _SKIP_FAILED_SERIALIZATION
+  prev = _SKIP_FAILED_SERIALIZATION
+  try:
+    _SKIP_FAILED_SERIALIZATION = True
+    yield
+  finally:
+    _SKIP_FAILED_SERIALIZATION = prev
+
+
 @keras_export('keras.utils.serialize_keras_object')
 def serialize_keras_object(instance):
   """Serialize Keras object into JSON."""
@@ -195,7 +212,13 @@ def serialize_keras_object(instance):
     return None
 
   if hasattr(instance, 'get_config'):
-    config = instance.get_config()
+    name = _get_name_or_custom_name(instance.__class__)
+    try:
+      config = instance.get_config()
+    except NotImplementedError as e:
+      if _SKIP_FAILED_SERIALIZATION:
+        return serialize_keras_class_and_config(name, None)
+      raise e
     serialization_config = {}
     for key, item in config.items():
       if isinstance(item, six.string_types):
@@ -211,15 +234,13 @@ def serialize_keras_object(instance):
         serialization_config[key] = serialized_item
       except ValueError:
         serialization_config[key] = item
-
-    name = _get_name_or_custom_name(instance.__class__)
     return serialize_keras_class_and_config(name, serialization_config)
   if hasattr(instance, '__name__'):
     return _get_name_or_custom_name(instance)
   raise ValueError('Cannot serialize', instance)
 
 
-def _get_custom_objects_by_name(item, custom_objects=None):
+def get_custom_objects_by_name(item, custom_objects=None):
   """Returns the item if it is in either local or global custom objects."""
   if item in _GLOBAL_CUSTOM_OBJECTS:
     return _GLOBAL_CUSTOM_OBJECTS[item]
@@ -260,7 +281,7 @@ def class_and_config_for_serialized_keras_object(
           printable_module_name='config_item')
     elif (isinstance(item, six.string_types) and
           tf_inspect.isfunction(
-              _get_custom_objects_by_name(item, custom_objects))):
+              get_custom_objects_by_name(item, custom_objects))):
       # Handle custom functions here. When saving functions, we only save the
       # function's name as a string. If we find a matching string in the custom
       # objects during deserialization, we convert the string back to the
@@ -269,7 +290,7 @@ def class_and_config_for_serialized_keras_object(
       # conflict with a custom function name, but this should be a rare case.
       # This issue does not occur if a string field has a naming conflict with
       # a custom object, since the config of an object will always be a dict.
-      deserialized_objects[key] = _get_custom_objects_by_name(
+      deserialized_objects[key] = get_custom_objects_by_name(
           item, custom_objects)
   for key, item in deserialized_objects.items():
     cls_config[key] = deserialized_objects[key]

From 91d0e95d82ca062a3f04d769c92636ebd74c3ff3 Mon Sep 17 00:00:00 2001
From: Dan Moldovan 
Date: Tue, 3 Dec 2019 15:34:02 -0800
Subject: [PATCH 249/279] Cleanup: consistently pass symbol names and loop
 options to all overloads. Slightly refactor code to reduce the size of shape
 verification functions.

PiperOrigin-RevId: 283636317
Change-Id: I828df08f8fec05d62e90985caec16399866e9bdf
---
 .../converters/conditional_expressions.py     |  11 +-
 .../autograph/converters/control_flow.py      |  30 +-
 .../autograph/operators/control_flow.py       | 423 +++++++++---------
 .../autograph/operators/control_flow_test.py  | 141 ++++--
 4 files changed, 355 insertions(+), 250 deletions(-)

diff --git a/tensorflow/python/autograph/converters/conditional_expressions.py b/tensorflow/python/autograph/converters/conditional_expressions.py
index 4538b16660c..125ef5375be 100644
--- a/tensorflow/python/autograph/converters/conditional_expressions.py
+++ b/tensorflow/python/autograph/converters/conditional_expressions.py
@@ -27,8 +27,15 @@ class ConditionalExpressionTransformer(converter.Base):
 
   def visit_IfExp(self, node):
     return templates.replace_as_expression(
-        '''ag__.if_stmt(test, lambda: true_expr,
-                        lambda: false_expr, lambda: (), lambda _: None)''',
+        '''ag__.if_stmt(
+            test,
+            lambda: true_expr,
+            lambda: false_expr,
+            lambda: (),
+            lambda _: None,
+            ('',),
+            ())
+        ''',
         test=node.test,
         true_expr=node.body,
         false_expr=node.orelse)
diff --git a/tensorflow/python/autograph/converters/control_flow.py b/tensorflow/python/autograph/converters/control_flow.py
index 8f4170281b4..5bf488cd209 100644
--- a/tensorflow/python/autograph/converters/control_flow.py
+++ b/tensorflow/python/autograph/converters/control_flow.py
@@ -383,6 +383,9 @@ class ControlFlowTransformer(converter.Base):
     composite_symbol_names = tuple(
         gast.Str(str(symbol)) for symbol in composite_loop_vars)
 
+    # TODO(b/140125096): Populate.
+    opts = gast.Dict([], [])
+
     # TODO(mdan): Use a single template.
     # If the body and test functions took a single tuple for loop_vars, instead
     # of *loop_vars, then a single template could be used.
@@ -401,7 +404,8 @@ class ControlFlowTransformer(converter.Base):
             state_setter_name,
             (loop_vars,),
             (basic_symbol_names,),
-            (composite_symbol_names,))
+            (composite_symbol_names,),
+            opts)
       """
       node = templates.replace(
           template,
@@ -415,7 +419,8 @@ class ControlFlowTransformer(converter.Base):
           state_getter_name=state_getter_name,
           state_setter_name=state_setter_name,
           basic_symbol_names=basic_symbol_names,
-          composite_symbol_names=composite_symbol_names)
+          composite_symbol_names=composite_symbol_names,
+          opts=opts)
     else:
       template = """
         state_functions
@@ -431,7 +436,8 @@ class ControlFlowTransformer(converter.Base):
             state_setter_name,
             (),
             (),
-            (composite_symbol_names,))
+            (composite_symbol_names,),
+            opts)
       """
       node = templates.replace(
           template,
@@ -442,7 +448,8 @@ class ControlFlowTransformer(converter.Base):
           state_functions=state_functions,
           state_getter_name=state_getter_name,
           state_setter_name=state_setter_name,
-          composite_symbol_names=composite_symbol_names)
+          composite_symbol_names=composite_symbol_names,
+          opts=opts)
 
     undefined_assigns = self._create_undefined_assigns(possibly_undefs)
     return undefined_assigns + node
@@ -500,6 +507,9 @@ class ControlFlowTransformer(converter.Base):
     composite_symbol_names = tuple(
         gast.Str(str(symbol)) for symbol in composite_loop_vars)
 
+    # TODO(b/140125096): Populate.
+    opts = gast.Dict([], [])
+
     # TODO(mdan): Use a single template.
     # If the body and test functions took a single tuple for loop_vars, instead
     # of *loop_vars, then a single template could be used.
@@ -520,7 +530,8 @@ class ControlFlowTransformer(converter.Base):
             state_setter_name,
             (loop_vars,),
             (basic_symbol_names,),
-            (composite_symbol_names,))
+            (composite_symbol_names,),
+            opts)
       """
       return templates.replace(
           template,
@@ -538,7 +549,8 @@ class ControlFlowTransformer(converter.Base):
           state_getter_name=state_getter_name,
           state_setter_name=state_setter_name,
           basic_symbol_names=basic_symbol_names,
-          composite_symbol_names=composite_symbol_names)
+          composite_symbol_names=composite_symbol_names,
+          opts=opts)
     else:
       template = """
         undefined_assigns
@@ -556,7 +568,8 @@ class ControlFlowTransformer(converter.Base):
             state_setter_name,
             (),
             (),
-            (composite_symbol_names,))
+            (composite_symbol_names,),
+            opts)
       """
       return templates.replace(
           template,
@@ -571,7 +584,8 @@ class ControlFlowTransformer(converter.Base):
           state_functions=state_functions,
           state_getter_name=state_getter_name,
           state_setter_name=state_setter_name,
-          composite_symbol_names=composite_symbol_names)
+          composite_symbol_names=composite_symbol_names,
+          opts=opts)
 
 
 def transform(node, ctx):
diff --git a/tensorflow/python/autograph/operators/control_flow.py b/tensorflow/python/autograph/operators/control_flow.py
index bbfee424315..c862379e1d0 100644
--- a/tensorflow/python/autograph/operators/control_flow.py
+++ b/tensorflow/python/autograph/operators/control_flow.py
@@ -109,160 +109,140 @@ def _disallow_undefs_into_loop(*values):
           'return statements are not supported within a TensorFlow loop.')
 
 
-def _shape_greater_than_or_equal(shape1, shape2):
-  """Check whether the shape2 is equal or more specific than shape1."""
-
-  # The following logic was mirrored from control_flow_ops.py's
-  # _ShapeLessThanOrEqual function.
-  if shape1.dims is None:
+def _is_subshape(left, right):
+  """Returns True if left shape is at least as specific as right shape."""
+  # TODO(mdan): This code should be in TensorShape.
+  # Note: this is not the same as TensorShape.is_compatible_with, which is
+  # symmetric.
+  # This code also duplicates _ShapeLessThanOrEqual from  control_flow_ops.py.
+  if right.dims is None:
     return True
-  if shape1.ndims != shape2.ndims:
+  if left.ndims != right.ndims:
     return False
-  for dim1, dim2 in zip(shape1.dims, shape2.dims):
-    if dim1.value is not None and dim1.value != dim2.value:
+  for ldim, rdim in zip(left.dims, right.dims):
+    if rdim.value is not None and ldim.value != rdim.value:
       return False
   return True
 
 
-def _verify_tf_loop_vars(init_loop_vars,
-                         first_iter_vars,
-                         basic_symbol_names,
-                         composite_symbol_names,
-                         include_shapes=True):
-  """Verifies loop variables for consistency."""
+def _verify_single_loop_var(name, check_shape, init_loop_var, first_iter_var):
+  """Verifies whether init_loop_var and first_iter_var are consistent."""
+  if isinstance(init_loop_var, (bool, int, float, str)):
+    init_loop_var = ops.convert_to_tensor_v2(init_loop_var)
 
-  # The whole point of _verify_tf_loop_vars is to give more useful error message
-  # than tf-level exception by including variable names.  If it's not available,
-  # there is no point at performing this verification here.  As of 2019-07-31,
-  # operators:control_flow_test does not pass the names.
-  if basic_symbol_names is None:
+  if isinstance(first_iter_var, (bool, int, float, str)):
+    first_iter_var = ops.convert_to_tensor_v2(first_iter_var)
+
+  if (not tensor_util.is_tensor(init_loop_var) or
+      not tensor_util.is_tensor(first_iter_var)):
     return
 
-  output_symbol_names = basic_symbol_names + composite_symbol_names
+  # TODO(mdan): Properly account for CompositeTensors.
+  if (not hasattr(init_loop_var, 'dtype') or
+      not hasattr(first_iter_var, 'dtype')):
+    return
+  if (not hasattr(init_loop_var, 'shape') or
+      not hasattr(first_iter_var, 'shape')):
+    return
 
-  assert len(init_loop_vars) == len(first_iter_vars) == len(output_symbol_names)
+  if init_loop_var.dtype != first_iter_var.dtype:
+    raise TypeError(
+        '"{}" has dtype {} before the loop, but dtype {} after one'
+        ' iteration. TensorFlow control flow requires it stays the'
+        ' same.'.format(
+            name,
+            init_loop_var.dtype.name,
+            first_iter_var.dtype.name,
+        ))
 
-  for init_loop_var, first_iter_var, name in zip(init_loop_vars,
-                                                 first_iter_vars,
-                                                 output_symbol_names):
+  if check_shape:
+    init_shape = init_loop_var.shape
+    first_iter_shape = first_iter_var.shape
+    # TODO(b/135183013): Update needed once we support shape_invariants.
+    if not _is_subshape(first_iter_shape, init_shape):
+      raise ValueError(
+          '"{}" has shape {} before the loop, but shape {} after one'
+          ' iteration. TensorFlow control flow requires it stays the'
+          ' same or be more specific.'.format(name, init_shape,
+                                              first_iter_shape))
 
+
+def _verify_tf_loop_vars(init_loop_vars,
+                         first_iter_vars,
+                         symbol_names,
+                         opts,
+                         check_shapes=True):
+  """Verifies loop variables for consistency."""
+  # TODO(b/140125096): Use this.
+  del opts
+
+  named_vars = zip(symbol_names, init_loop_vars, first_iter_vars)
+  for name, init_loop_var, first_iter_var in named_vars:
     try:
       nest.assert_same_structure(
           init_loop_var, first_iter_var, expand_composites=True)
     except (ValueError, TypeError) as e:
       raise TypeError('"{}" does not have the same nested structure after one'
                       ' iteration.\n\n{}'.format(name, e))
-
-    def _check_same_type(name, init_loop_var, first_iter_var):
-      """Ensures init_loop_var and first_iter_var are consistent."""
-      if isinstance(init_loop_var, (bool, int, float, str)):
-        init_loop_var = ops.convert_to_tensor_v2(init_loop_var)
-
-      if isinstance(first_iter_var, (bool, int, float, str)):
-        first_iter_var = ops.convert_to_tensor_v2(first_iter_var)
-
-      if (not tensor_util.is_tensor(init_loop_var) or
-          not tensor_util.is_tensor(first_iter_var)):
-        return
-
-      # TODO(mdan): Properly account for CompositeTensors.
-      if (not hasattr(init_loop_var, 'dtype') or
-          not hasattr(first_iter_var, 'dtype')):
-        return
-      if (not hasattr(init_loop_var, 'shape') or
-          not hasattr(first_iter_var, 'shape')):
-        return
-
-      if init_loop_var.dtype != first_iter_var.dtype:
-        raise TypeError(
-            '"{}" has dtype {} before the loop, but dtype {} after one'
-            ' iteration. TensorFlow control flow requires it stays the'
-            ' same.'.format(
-                name,
-                init_loop_var.dtype.name,
-                first_iter_var.dtype.name,
-            ))
-
-      if include_shapes:
-        init_shape = init_loop_var.shape
-        first_iter_shape = first_iter_var.shape
-        # TODO(b/135183013): Update needed once we support shape_invariants.
-        if not _shape_greater_than_or_equal(init_shape, first_iter_shape):
-          raise ValueError(
-              '"{}" has shape {} before the loop, but shape {} after one'
-              ' iteration. TensorFlow control flow requires it stays the'
-              ' same or be more specific.'.format(name, init_shape,
-                                                  first_iter_shape))
-
     nest.map_structure(
-        functools.partial(_check_same_type, name), init_loop_var,
-        first_iter_var)
+        functools.partial(_verify_single_loop_var, name, check_shapes),
+        init_loop_var, first_iter_var)
 
 
-def _verify_tf_cond_vars(body_outputs, orelse_outputs, basic_symbol_names,
-                         composite_symbol_names):
-  """Verifies variables manipulated by a conditional for consistency."""
+def _verify_single_cond_var(name, body_var, orelse_var):
+  """Verifies whether body_var and orelse_var are consistent."""
+  if isinstance(body_var, (bool, int, float, str)):
+    body_var = ops.convert_to_tensor_v2(body_var)
 
-  # The whole point of _verify_tf_cond_vars is to give more useful error message
-  # than tf-level exception by including variable names.  If it's not available,
-  # there is no point at performing this verification here.  As of 2019-07-31,
-  # conditional expression does not pass the names.
-  if basic_symbol_names is None:
+  if isinstance(orelse_var, (bool, int, float, str)):
+    orelse_var = ops.convert_to_tensor_v2(orelse_var)
+
+  if (not tensor_util.is_tensor(body_var) or
+      not tensor_util.is_tensor(orelse_var)):
     return
 
-  output_symbol_names = basic_symbol_names + composite_symbol_names
+  # TODO(mdan): Properly account for CompositeTensors.
+  if (not hasattr(body_var, 'dtype') or
+      not hasattr(orelse_var, 'dtype')):
+    return
 
-  basic_body_outputs, composite_body_outputs = body_outputs
-  basic_orelse_outputs, composite_orelse_outputs = orelse_outputs
-  assert isinstance(composite_body_outputs, tuple)
-  assert isinstance(composite_orelse_outputs, tuple)
+  if body_var.dtype != orelse_var.dtype:
+    raise TypeError(
+        '"{}" has dtype {} in the TRUE branch, but dtype={} in the FALSE'
+        ' branch. TensorFlow control flow requires that they are the'
+        ' same.'.format(name, body_var.dtype.name,
+                        orelse_var.dtype.name))
+
+
+def _verify_tf_cond_vars(body_vars, orelse_vars, symbol_names):
+  """Verifies variables manipulated by a conditional for consistency."""
+  basic_body_vars, composite_body_vars = body_vars
+  basic_orelse_vars, composite_orelse_vars = orelse_vars
+  assert isinstance(composite_body_vars, tuple)
+  assert isinstance(composite_orelse_vars, tuple)
 
   # TODO(kkimlabs): Make this more consistent.
   # The basic outputs should always be a tuple.
-  if not isinstance(basic_body_outputs, tuple):
-    basic_body_outputs = (basic_body_outputs,)
-  if not isinstance(basic_orelse_outputs, tuple):
-    basic_orelse_outputs = (basic_orelse_outputs,)
+  if not isinstance(basic_body_vars, tuple):
+    basic_body_vars = (basic_body_vars,)
+  if not isinstance(basic_orelse_vars, tuple):
+    basic_orelse_vars = (basic_orelse_vars,)
 
-  body_outputs = basic_body_outputs + composite_body_outputs
-  orelse_outputs = basic_orelse_outputs + composite_orelse_outputs
+  body_vars = basic_body_vars + composite_body_vars
+  orelse_vars = basic_orelse_vars + composite_orelse_vars
 
-  for body_output, orelse_output, name in zip(body_outputs, orelse_outputs,
-                                              output_symbol_names):
+  named_vars = zip(symbol_names, body_vars, orelse_vars)
+  for name, body_var, orelse_var in named_vars:
     try:
       nest.assert_same_structure(
-          body_output, orelse_output, expand_composites=True)
+          body_var, orelse_var, expand_composites=True)
     except (ValueError, TypeError) as e:
       raise TypeError(
           '"{}" does not have the same nested structure in the TRUE and FALSE'
           ' branches.\n\n{}'.format(name, str(e)))
 
-    def _check_same_type(name, body_output_var, orelse_output_var):
-      """Verfies that body_output_var and orelse_output_var have same dtype."""
-      if isinstance(body_output_var, (bool, int, float, str)):
-        body_output_var = ops.convert_to_tensor_v2(body_output_var)
-
-      if isinstance(orelse_output_var, (bool, int, float, str)):
-        orelse_output_var = ops.convert_to_tensor_v2(orelse_output_var)
-
-      if (not tensor_util.is_tensor(body_output_var) or
-          not tensor_util.is_tensor(orelse_output_var)):
-        return
-
-      # TODO(mdan): Properly account for CompositeTensors.
-      if (not hasattr(body_output_var, 'dtype') or
-          not hasattr(orelse_output_var, 'dtype')):
-        return
-
-      if body_output_var.dtype != orelse_output_var.dtype:
-        raise TypeError(
-            '"{}" has dtype {} in the TRUE branch, but dtype={} in the FALSE'
-            ' branch. TensorFlow control flow requires that they are the'
-            ' same.'.format(name, body_output_var.dtype.name,
-                            orelse_output_var.dtype.name))
-
     nest.map_structure(
-        functools.partial(_check_same_type, name), body_output, orelse_output)
+        functools.partial(_verify_single_cond_var, name), body_var, orelse_var)
 
 
 def for_stmt(iter_,
@@ -271,8 +251,9 @@ def for_stmt(iter_,
              get_state,
              set_state,
              init_vars,
-             basic_symbol_names=None,
-             composite_symbol_names=None):
+             basic_symbol_names,
+             composite_symbol_names,
+             opts):
   """Functional form of a for statement.
 
   The loop operates on a state, which includes all symbols that are
@@ -308,6 +289,7 @@ def for_stmt(iter_,
     init_vars: Tuple containing the initial state.
     basic_symbol_names: Tuple containing basic loop var names.
     composite_symbol_names: Tuple containing composite loop var names.
+    opts: Optional dict of extra loop parameters.
 
   Returns:
     Tuple containing the final state.
@@ -316,26 +298,26 @@ def for_stmt(iter_,
     if tensors.is_range_tensor(iter_):
       return _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state,
                                 init_vars, basic_symbol_names,
-                                composite_symbol_names)
+                                composite_symbol_names, opts)
     else:
       return _known_len_tf_for_stmt(iter_, extra_test, body, get_state,
                                     set_state, init_vars, basic_symbol_names,
-                                    composite_symbol_names)
+                                    composite_symbol_names, opts)
 
   if isinstance(iter_, dataset_ops.DatasetV2):
     return _tf_dataset_for_stmt(iter_, extra_test, body, get_state, set_state,
                                 init_vars, basic_symbol_names,
-                                composite_symbol_names)
+                                composite_symbol_names, opts)
 
   if isinstance(iter_, iterator_ops.OwnedIterator):
     return _tf_iterator_for_stmt(iter_, extra_test, body, get_state, set_state,
                                  init_vars, basic_symbol_names,
-                                 composite_symbol_names)
+                                 composite_symbol_names, opts)
 
   if isinstance(iter_, ragged_tensor.RaggedTensor):
     return _tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state,
                                init_vars, basic_symbol_names,
-                               composite_symbol_names)
+                               composite_symbol_names, opts)
 
   # Note: This experimental interface is subject to change.
   custom_handler = getattr(iter_, '_autograph_for_loop', None)
@@ -360,9 +342,15 @@ def _py_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars):
   return state
 
 
-def _known_len_tf_for_stmt(iter_, extra_test, body, get_state, set_state,
-                           init_vars, basic_symbol_names,
-                           composite_symbol_names):
+def _known_len_tf_for_stmt(iter_,
+                           extra_test,
+                           body,
+                           get_state,
+                           set_state,
+                           init_vars,
+                           basic_symbol_names,
+                           composite_symbol_names,
+                           opts):
   """Overload of for_stmt that iterates over TF entities that admit a length."""
   _disallow_undefs_into_loop(*init_vars)
 
@@ -377,8 +365,6 @@ def _known_len_tf_for_stmt(iter_, extra_test, body, get_state, set_state,
     """Main loop body."""
     iterate = iter_.read(iterate_index)
     new_vars = body(iterate, *loop_vars)
-    _verify_tf_loop_vars(loop_vars, new_vars, basic_symbol_names,
-                         composite_symbol_names)
 
     loop_vars = (iterate_index + 1,)
     if new_vars:
@@ -388,13 +374,12 @@ def _known_len_tf_for_stmt(iter_, extra_test, body, get_state, set_state,
 
   def while_cond(iterate_index, *loop_vars):
     if extra_test is not None:
-      return control_flow_ops.cond(
-          iterate_index < n, lambda: extra_test(*loop_vars), lambda: False)
+      return control_flow_ops.cond(iterate_index < n,
+                                   lambda: extra_test(*loop_vars),
+                                   lambda: False)
     return iterate_index < n
 
-  opts = {}
-  # TODO(b/134181679): We do not always set maximum_iterations since that
-  # is significantly slower on GPU.
+  # TODO(b/134181679): Let the op itself handle optimizations.
   if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
     opts['maximum_iterations'] = n
 
@@ -403,10 +388,10 @@ def _known_len_tf_for_stmt(iter_, extra_test, body, get_state, set_state,
       while_body,
       get_state,
       set_state,
-      (0,) + init_vars,
-      None,
-      None,
-      opts=opts,
+      (array_ops.zeros_like(n),) + init_vars,
+      ('',) + basic_symbol_names,
+      composite_symbol_names,
+      opts,
   )
 
   # Note: the iteration index is not returned by the while loop, however
@@ -422,9 +407,15 @@ def _known_len_tf_for_stmt(iter_, extra_test, body, get_state, set_state,
   return results
 
 
-def _tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state,
-                        init_vars, basic_symbol_names,
-                        composite_symbol_names):
+def _tf_ragged_for_stmt(iter_,
+                        extra_test,
+                        body,
+                        get_state,
+                        set_state,
+                        init_vars,
+                        basic_symbol_names,
+                        composite_symbol_names,
+                        opts):
   """Overload of for_stmt that iterates over TF ragged tensors."""
   _disallow_undefs_into_loop(*init_vars)
 
@@ -438,8 +429,6 @@ def _tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state,
     """Main loop body."""
     iterate = iter_[iterate_index]
     new_vars = body(iterate, *loop_vars)
-    _verify_tf_loop_vars(loop_vars, new_vars, basic_symbol_names,
-                         composite_symbol_names)
 
     loop_vars = (iterate_index + 1,)
     if new_vars:
@@ -450,10 +439,13 @@ def _tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state,
   def while_cond(iterate_index, *loop_vars):
     if extra_test is not None:
       return control_flow_ops.cond(
-          iterate_index < n, lambda: extra_test(*loop_vars), lambda: False)
+          iterate_index < n,
+          lambda: extra_test(*loop_vars),
+          lambda: False,
+      )
     return iterate_index < n
 
-  opts = {'maximum_iterations': n}
+  opts['maximum_iterations'] = n
 
   results = _tf_while_stmt(
       while_cond,
@@ -461,9 +453,9 @@ def _tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state,
       get_state,
       set_state,
       (array_ops.zeros_like(n),) + init_vars,
-      None,
-      None,
-      opts=opts,
+      ('',) + basic_symbol_names,
+      composite_symbol_names,
+      opts,
   )
 
   if isinstance(results, (tuple, list)):
@@ -476,8 +468,15 @@ def _tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state,
   return results
 
 
-def _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars,
-                       basic_symbol_names, composite_symbol_names):
+def _tf_range_for_stmt(iter_,
+                       extra_test,
+                       body,
+                       get_state,
+                       set_state,
+                       init_vars,
+                       basic_symbol_names,
+                       composite_symbol_names,
+                       opts):
   """Overload of for_stmt that iterates over a TF range (and elides it)."""
   _disallow_undefs_into_loop(*init_vars)
 
@@ -497,8 +496,9 @@ def _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars,
 
     def build_main_test():
       """Main iteration condition."""
-      # Note(b/138857806): LogicalAnd is slow on GPU so we avoid adding it if
-      # `delta` is a compile time constant.
+      # TODO(b/138857806): The optimizer should handle this.
+      # LogicalAnd is slow on GPU so we avoid adding it if `delta` is a
+      # compile time constant.
       delta_const = tensor_util.constant_value(delta)
       if delta_const is not None:
         # Support single element arrays.
@@ -515,16 +515,13 @@ def _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars,
     main_test = build_main_test()
     if extra_test is not None:
       return control_flow_ops.cond(
-          main_test, lambda: extra_test(*loop_vars), lambda: False)
+          main_test,
+          lambda: extra_test(*loop_vars),
+          lambda: False,
+      )
     return main_test
 
-  # The first loopvar corresponds to the iterate variable which is internal.
-  if isinstance(basic_symbol_names, tuple):
-    basic_symbol_names = (None,) + basic_symbol_names
-
-  opts = {}
-  # TODO(b/134181679): We do not always set maximum_iterations since that
-  # is significantly slower on GPU.
+  # TODO(b/134181679): The op should handle this optimizations.
   if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
     # This specific dtype is required by while_loop.
     opts['maximum_iterations'] = math_ops.cast(
@@ -536,9 +533,9 @@ def _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars,
       get_state,
       set_state,
       (start,) + init_vars,
-      basic_symbol_names,
+      ('',) + basic_symbol_names,
       composite_symbol_names,
-      opts=opts,
+      opts,
   )
 
   # Note: the iteration index is not returned by the while loop, however
@@ -556,21 +553,24 @@ def _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars,
 
 def _tf_iterator_for_stmt(itr, extra_test, body, get_state, set_state,
                           init_vars, basic_symbol_names,
-                          composite_symbol_names):
+                          composite_symbol_names, opts):
   """Overload of for_stmt that iterates over TF Iterators. See for_loop."""
   _disallow_undefs_into_loop(*init_vars)
 
   def while_body_actual(opt_iterate, *loop_vars):
     """Actual main loop body."""
     new_vars = body(opt_iterate.get_value(), *loop_vars)
-    _verify_tf_loop_vars(loop_vars, new_vars, basic_symbol_names,
-                         composite_symbol_names)
     # TODO(mdan): Fix this inconsistency in the converter.
     if new_vars is None:
       new_vars = ()
+    # Note: this verification duplicates that perfrmed in tf_while_stmt,
+    # but needs to be done earlier to prevent the tf.cond inside while_body
+    # from blowing up first.
+    _verify_tf_loop_vars(loop_vars, new_vars,
+                         basic_symbol_names + composite_symbol_names, opts)
     return new_vars
 
-  def while_body(has_next, loop_vars):
+  def while_body(has_next, *loop_vars):
     """Main loop body."""
     opt_iterate = iterator_ops.get_next_as_optional(itr)
     has_next = opt_iterate.has_value()
@@ -591,30 +591,32 @@ def _tf_iterator_for_stmt(itr, extra_test, body, get_state, set_state,
     if dummy_state:
       new_vars = new_vars[1:]
 
-    return has_next, new_vars
+    return (has_next,) + new_vars
 
-  def while_cond(has_next, loop_vars):
+  def while_cond(has_next, *loop_vars):
     if extra_test is not None:
       return control_flow_ops.cond(
-          has_next, lambda: extra_test(*loop_vars), lambda: False)
+          has_next,
+          lambda: extra_test(*loop_vars),
+          lambda: False,
+      )
     return has_next
 
-  # The first loopvar corresponds to the iterate variable which is internal.
-  _, final_vars = _tf_while_stmt(
+  final_vars = _tf_while_stmt(
       while_cond,
       while_body,
       get_state,
       set_state,
-      (True, init_vars),
-      None,
-      None,
-      opts=None,
+      (True,) + init_vars,
+      ('',) + basic_symbol_names,
+      composite_symbol_names,
+      opts,
   )
-  return final_vars
+  return final_vars[1:]
 
 
 def _tf_dataset_for_stmt(ds, extra_test, body, get_state, set_state, init_vars,
-                         basic_symbol_names, composite_symbol_names):
+                         basic_symbol_names, composite_symbol_names, opts):
   """Overload of for_stmt that iterates over TF Datasets."""
   _disallow_undefs_into_loop(*init_vars)
 
@@ -623,11 +625,11 @@ def _tf_dataset_for_stmt(ds, extra_test, body, get_state, set_state, init_vars,
     return _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state,
                                              set_state, init_vars,
                                              basic_symbol_names,
-                                             composite_symbol_names)
+                                             composite_symbol_names, opts)
 
   return _dataset_for_stmt_no_extra_test(ds, body, get_state, set_state,
                                          init_vars, basic_symbol_names,
-                                         composite_symbol_names)
+                                         composite_symbol_names, opts)
 
 
 def _general_purpose_scan(ds, init_state, body):
@@ -646,7 +648,7 @@ def _general_purpose_scan(ds, init_state, body):
 
 def _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state,
                                       set_state, init_vars, basic_symbol_names,
-                                      composite_symbol_names):
+                                      composite_symbol_names, opts):
   """Overload of _dataset_for_stmt with early stopping. See for_stmt."""
 
   # TODO(mdan): Simplify this - following it is extremely difficult.
@@ -661,14 +663,17 @@ def _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state,
       _verify_tf_loop_vars(
           loop_vars + state,
           outputs + state,
-          basic_symbol_names,
-          composite_symbol_names,
-          include_shapes=False)
+          basic_symbol_names + composite_symbol_names,
+          opts,
+          check_shapes=False)
       return outputs, get_state()
 
     extra_cond = extra_test(*loop_vars)
     new_vars, new_state = control_flow_ops.cond(
-        extra_cond, true_fn, lambda: (loop_vars, state))
+        extra_cond,
+        true_fn,
+        lambda: (loop_vars, state),
+    )
 
     scan_outputs = new_vars, new_state, extra_cond
     # Note: new_aug_vars is the actual state of scan; scan_outputs is its output
@@ -696,12 +701,15 @@ def _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state,
 
 
 def _dataset_for_stmt_no_extra_test(ds, body, get_state, set_state, init_vars,
-                                    basic_symbol_names, composite_symbol_names):
+                                    basic_symbol_names, composite_symbol_names,
+                                    opts):
   """Overload of _dataset_for_stmt without early stopping. See for_stmt."""
   init_state = get_state()
   assert isinstance(init_vars, tuple)
   assert isinstance(init_state, tuple)
 
+  symbol_names = basic_symbol_names + composite_symbol_names
+
   # Workaround for Dataset.reduce not allowing empty state tensors - create
   # a dummy state variable that remains unused.
   # TODO(mdan): reduce should allow and match empty structures.
@@ -710,10 +718,10 @@ def _dataset_for_stmt_no_extra_test(ds, body, get_state, set_state, init_vars,
 
   if no_vars:
     init_vars = (constant_op.constant(0),)
-    if isinstance(basic_symbol_names, tuple):
-      basic_symbol_names = (None,) + basic_symbol_names
+    symbol_names = ('',) + symbol_names
   if no_state:
     init_state = (constant_op.constant(0),)
+    symbol_names = symbol_names + ('',)
 
   def scan_body(aug_vars, iterate):
     """The main loop body wrapper."""
@@ -735,9 +743,9 @@ def _dataset_for_stmt_no_extra_test(ds, body, get_state, set_state, init_vars,
     _verify_tf_loop_vars(
         loop_vars + state,
         new_vars + new_state,
-        basic_symbol_names,
-        composite_symbol_names,
-        include_shapes=False)
+        symbol_names,
+        opts,
+        check_shapes=False)
 
     scan_outputs = new_vars, new_state
     # Note: new_aug_vars is the actual state of scan; scan_outputs is its output
@@ -760,16 +768,14 @@ def _dataset_for_stmt_no_extra_test(ds, body, get_state, set_state, init_vars,
   return final_vars
 
 
-def while_stmt(
-    test,
-    body,
-    get_state,
-    set_state,
-    init_vars,
-    basic_symbol_names=None,
-    composite_symbol_names=None,
-    opts=None,
-):
+def while_stmt(test,
+               body,
+               get_state,
+               set_state,
+               init_vars,
+               basic_symbol_names,
+               composite_symbol_names,
+               opts):
   """Functional form of a while statement.
 
   The loop operates on a so-called state, which includes all symbols that are
@@ -818,17 +824,11 @@ def while_stmt(
   return _py_while_stmt(test, body, get_state, set_state, init_vars, opts)
 
 
-# TODO(kkimlabs): Some callers set basic_symbol_names=None and
-# composite_symbol_names=None and call _verify_tf_loop_vars(...) itself.  We can
-# remove these arguments once all callers do that.
 def _tf_while_stmt(test, body, get_state, set_state, init_vars,
                    basic_symbol_names, composite_symbol_names, opts):
   """Overload of while_stmt that stages a TF while_stmt."""
   _disallow_undefs_into_loop(*init_vars)
 
-  if opts is None:
-    opts = {}
-
   # TODO(mdan): Simplify this.
   loop_vars_slice = slice(len(init_vars))
   state_slice = slice(len(init_vars), None)
@@ -839,12 +839,13 @@ def _tf_while_stmt(test, body, get_state, set_state, init_vars,
     return test(*aug_loop_vars[loop_vars_slice])
 
   def aug_body(*aug_loop_vars):
+    """Main loop body."""
     state = aug_loop_vars[state_slice]
     set_state(state)
     loop_vars = body(*aug_loop_vars[loop_vars_slice])
     new_state = loop_vars + get_state()
-    _verify_tf_loop_vars(aug_loop_vars, new_state, basic_symbol_names,
-                         composite_symbol_names)
+    _verify_tf_loop_vars(aug_loop_vars, new_state,
+                         basic_symbol_names + composite_symbol_names, opts)
 
     return new_state
 
@@ -948,8 +949,8 @@ def if_stmt(cond,
             orelse,
             get_state,
             set_state,
-            basic_symbol_names=None,
-            composite_symbol_names=None):
+            basic_symbol_names,
+            composite_symbol_names):
   """Functional form of an if statement.
 
   Args:
@@ -1005,14 +1006,14 @@ def tf_if_stmt(cond, body, orelse, get_state, set_state, basic_symbol_names,
     result[body_branch] = body()
     if result[orelse_branch] is not None:
       _verify_tf_cond_vars(result[body_branch], result[orelse_branch],
-                           basic_symbol_names, composite_symbol_names)
+                           basic_symbol_names + composite_symbol_names)
     return result[body_branch]
 
   def error_checking_orelse():
     result[orelse_branch] = orelse()
     if result[body_branch] is not None:
       _verify_tf_cond_vars(result[body_branch], result[orelse_branch],
-                           basic_symbol_names, composite_symbol_names)
+                           basic_symbol_names + composite_symbol_names)
     return result[orelse_branch]
 
   final_vars, final_state = control_flow_ops.cond(cond, error_checking_body,
diff --git a/tensorflow/python/autograph/operators/control_flow_test.py b/tensorflow/python/autograph/operators/control_flow_test.py
index 2290d61c6fd..a85d74246a1 100644
--- a/tensorflow/python/autograph/operators/control_flow_test.py
+++ b/tensorflow/python/autograph/operators/control_flow_test.py
@@ -50,7 +50,10 @@ class ForLoopTest(test.TestCase):
         body=lambda i, s: (s * 10 + i,),
         get_state=lambda: (),
         set_state=lambda _: None,
-        init_vars=(0,))
+        init_vars=(0,),
+        basic_symbol_names=('s',),
+        composite_symbol_names=(),
+        opts={})
     self.assertEqual(self.evaluate(s), (1234,))
 
   def test_range_tensor(self):
@@ -60,7 +63,10 @@ class ForLoopTest(test.TestCase):
         body=lambda i, s: (s * 10 + i,),
         get_state=lambda: (),
         set_state=lambda _: None,
-        init_vars=(0,))
+        init_vars=(0,),
+        basic_symbol_names=('s',),
+        composite_symbol_names=(),
+        opts={})
     self.assertEqual(self.evaluate(s), (1234,))
 
   def test_range_tensor_random_delta(self):
@@ -71,7 +77,10 @@ class ForLoopTest(test.TestCase):
         body=lambda i, s: (s * 10 + i,),
         get_state=lambda: (),
         set_state=lambda _: None,
-        init_vars=(0,))
+        init_vars=(0,),
+        basic_symbol_names=('s',),
+        composite_symbol_names=(),
+        opts={})
     self.assertEqual(self.evaluate(s), (1234,))
 
   def test_range_tensor_explicit_limit_delta(self):
@@ -81,7 +90,10 @@ class ForLoopTest(test.TestCase):
         body=lambda i, s: (s * 100 + i,),
         get_state=lambda: (),
         set_state=lambda _: None,
-        init_vars=(0,))
+        init_vars=(0,),
+        basic_symbol_names=('s',),
+        composite_symbol_names=(),
+        opts={})
     self.assertEqual(self.evaluate(s), (-171207,))
 
   def test_range_tensor_random_negative_delta(self):
@@ -92,7 +104,10 @@ class ForLoopTest(test.TestCase):
         body=lambda i, s: (s * 100 + i,),
         get_state=lambda: (),
         set_state=lambda _: None,
-        init_vars=(0,))
+        init_vars=(0,),
+        basic_symbol_names=('s',),
+        composite_symbol_names=(),
+        opts={})
     self.assertEqual(self.evaluate(s), (171207,))
 
   def test_range_tensor_negative_delta(self):
@@ -102,7 +117,10 @@ class ForLoopTest(test.TestCase):
         body=lambda i, s: (s * 100 + i,),
         get_state=lambda: (),
         set_state=lambda _: None,
-        init_vars=(0,))
+        init_vars=(0,),
+        basic_symbol_names=('s',),
+        composite_symbol_names=(),
+        opts={})
     self.assertEqual(self.evaluate(s), (171207,))
 
   def test_tensor_with_extra_test_only_python_state(self):
@@ -128,7 +146,10 @@ class ForLoopTest(test.TestCase):
         extra_test=lambda: state.field_1 < 6,
         get_state=get_state,
         set_state=set_state,
-        init_vars=())
+        init_vars=(),
+        basic_symbol_names=(),
+        composite_symbol_names=(),
+        opts={})
     self.assertEqual(self.evaluate(state.field_1), 6)
     self.assertEqual(self.evaluate(state.field_2), 6)
 
@@ -139,7 +160,10 @@ class ForLoopTest(test.TestCase):
         body=lambda i, s: (s * 10 + i,),
         get_state=None,
         set_state=None,
-        init_vars=(0,))
+        init_vars=(0,),
+        basic_symbol_names=('s',),
+        composite_symbol_names=(),
+        opts={})
     self.assertEqual(s, (1234,))
 
   def test_tf_dataset(self):
@@ -149,7 +173,10 @@ class ForLoopTest(test.TestCase):
         body=lambda i, s: (s * 10 + i,),
         get_state=lambda: (),
         set_state=lambda _: None,
-        init_vars=(constant_op.constant(0, dtype=dtypes.int64),))
+        init_vars=(constant_op.constant(0, dtype=dtypes.int64),),
+        basic_symbol_names=('s',),
+        composite_symbol_names=(),
+        opts={})
     self.assertEqual(self.evaluate(s), (1234,))
 
   def test_dataset_with_extra_test(self):
@@ -159,7 +186,10 @@ class ForLoopTest(test.TestCase):
         body=lambda i, s: (s + i,),
         get_state=lambda: (),
         set_state=lambda _: None,
-        init_vars=(constant_op.constant(0, dtype=dtypes.int64),))
+        init_vars=(constant_op.constant(0, dtype=dtypes.int64),),
+        basic_symbol_names=('s',),
+        composite_symbol_names=(),
+        opts={})
     self.assertEqual(self.evaluate(s), (3,))
 
   def test_dataset_with_extra_test_and_state(self):
@@ -181,7 +211,10 @@ class ForLoopTest(test.TestCase):
         body=body,
         get_state=get_state,
         set_state=set_state,
-        init_vars=(constant_op.constant(0, dtype=dtypes.int64),))
+        init_vars=(constant_op.constant(0, dtype=dtypes.int64),),
+        basic_symbol_names=('s',),
+        composite_symbol_names=(),
+        opts={})
     self.assertEqual(self.evaluate(s), (3,))
     self.assertEqual(self.evaluate(state[0]), (3,))
 
@@ -197,7 +230,10 @@ class ForLoopTest(test.TestCase):
         body=guarded_body,
         get_state=lambda: (),
         set_state=lambda _: None,
-        init_vars=(constant_op.constant(0, dtype=dtypes.int64),))
+        init_vars=(constant_op.constant(0, dtype=dtypes.int64),),
+        basic_symbol_names=('s',),
+        composite_symbol_names=(),
+        opts={})
     self.assertEqual(self.evaluate(s), (3,))
 
   def test_tf_dataset_no_loop_vars(self):
@@ -217,7 +253,10 @@ class ForLoopTest(test.TestCase):
           body=stateless_with_side_effects,
           get_state=lambda: (),
           set_state=lambda _: None,
-          init_vars=())
+          init_vars=(),
+          basic_symbol_names=('i',),
+          composite_symbol_names=(),
+          opts={})
 
     self.evaluate(test_fn())
     self.assertEqual(self.evaluate(v.read_value()), 1234)
@@ -233,7 +272,10 @@ class ForLoopTest(test.TestCase):
           body=lambda i, s: (s * 10 + i,),
           get_state=lambda: (),
           set_state=lambda _: None,
-          init_vars=(constant_op.constant(0, dtype=dtypes.int64),))
+          init_vars=(constant_op.constant(0, dtype=dtypes.int64),),
+          basic_symbol_names=('s',),
+          composite_symbol_names=(),
+          opts={})
     s, = test_fn()
     self.assertAllEqual(s, 1234)
 
@@ -253,7 +295,10 @@ class ForLoopTest(test.TestCase):
           body=stateless_with_side_effects,
           get_state=lambda: (),
           set_state=lambda _: None,
-          init_vars=())
+          init_vars=(),
+          basic_symbol_names=('i',),
+          composite_symbol_names=(),
+          opts={})
 
     self.evaluate(test_fn())
     self.assertEqual(self.evaluate(v.read_value()), 1234)
@@ -265,7 +310,10 @@ class ForLoopTest(test.TestCase):
         body=lambda i, s: (s * 10 + i[0],),
         get_state=lambda: (),
         set_state=lambda _: None,
-        init_vars=(0,))
+        init_vars=(0,),
+        basic_symbol_names=('s',),
+        composite_symbol_names=(),
+        opts={})
     self.assertEqual(self.evaluate(s), (123,))
 
   def test_tf_ragged_tensor_higher_dimensional(self):
@@ -279,7 +327,10 @@ class ForLoopTest(test.TestCase):
         body=lambda i, s: (s * 10 + i[0][0],),
         get_state=lambda: (),
         set_state=lambda _: None,
-        init_vars=(0,))
+        init_vars=(0,),
+        basic_symbol_names=('s',),
+        composite_symbol_names=(),
+        opts={})
     self.assertEqual(self.evaluate(s), (12,))
 
   def test_tf_ragged_tensor_no_loop_vars(self):
@@ -298,7 +349,10 @@ class ForLoopTest(test.TestCase):
           body=stateless_with_side_effects,
           get_state=lambda: (),
           set_state=lambda _: None,
-          init_vars=())
+          init_vars=(),
+          basic_symbol_names=(),
+          composite_symbol_names=(),
+          opts={})
 
     self.evaluate(test_fn())
     # Note: 123 = ((0*10 + 1)*10+2)*10+3 (first element of each row).
@@ -315,7 +369,10 @@ class WhileLoopTest(test.TestCase):
         body=lambda i, s: (i + 1, s + i),
         get_state=lambda: (),
         set_state=lambda _: None,
-        init_vars=(0, 0))
+        init_vars=(0, 0),
+        basic_symbol_names=('i', 's'),
+        composite_symbol_names=(),
+        opts={})
     self.assertEqual((5, 10), self.evaluate(results))
 
   def test_tensor_with_tf_side_effects_in_cond(self):
@@ -334,7 +391,10 @@ class WhileLoopTest(test.TestCase):
           body=lambda i: (i + 1,),
           get_state=lambda: (),
           set_state=lambda _: None,
-          init_vars=(0,))
+          init_vars=(0,),
+          basic_symbol_names=('i',),
+          composite_symbol_names=(),
+          opts={})
 
     results = test_fn()
 
@@ -364,7 +424,10 @@ class WhileLoopTest(test.TestCase):
         body=body,
         get_state=get_state,
         set_state=set_state,
-        init_vars=(0, 0))
+        init_vars=(0, 0),
+        basic_symbol_names=('i',),
+        composite_symbol_names=(),
+        opts={})
     self.assertEqual(self.evaluate(s), (5, 10))
     self.assertEqual(self.evaluate(state.field), 10)
 
@@ -375,7 +438,10 @@ class WhileLoopTest(test.TestCase):
         body=lambda i, s: (i + 1, s + i),
         get_state=lambda: (),
         set_state=lambda _: None,
-        init_vars=(0, constant_op.constant(0)))
+        init_vars=(0, constant_op.constant(0)),
+        basic_symbol_names=('i', 's'),
+        composite_symbol_names=(),
+        opts={})
     result_i, result_s = results
     self.assertEqual(5, result_i)
     self.assertEqual(10, self.evaluate(result_s))
@@ -387,7 +453,10 @@ class WhileLoopTest(test.TestCase):
         body=lambda i, s: (i + 1, s + i),
         get_state=None,
         set_state=None,
-        init_vars=(0, 0))
+        init_vars=(0, 0),
+        basic_symbol_names=('i', 's'),
+        composite_symbol_names=(),
+        opts={})
     self.assertEqual((5, 10), results)
 
   def test_python_infinite_loop(self):
@@ -399,7 +468,10 @@ class WhileLoopTest(test.TestCase):
               body=lambda i: (i + 1,),
               get_state=None,
               set_state=None,
-              init_vars=(0,))
+              init_vars=(0,),
+              basic_symbol_names=('i',),
+              composite_symbol_names=(),
+              opts={})
 
   def test_python_long_loop_unroll_warning(self):
     if __debug__:
@@ -415,7 +487,10 @@ class WhileLoopTest(test.TestCase):
                 body=lambda i, _: (i + 1, gen_math_ops.add(i, 1),),
                 get_state=None,
                 set_state=None,
-                init_vars=(0, None))
+                init_vars=(0, None),
+                basic_symbol_names=('i',),
+                composite_symbol_names=(),
+                opts={})
           self.assertTrue(re.match(
               r'.*ops.*loop.*large.*iterations.*Add.*',
               out_capturer.getvalue()))
@@ -432,7 +507,9 @@ class IfStmtTest(test.TestCase):
           body=lambda: constant_op.constant(1),
           orelse=lambda: constant_op.constant(-1),
           get_state=lambda: (),
-          set_state=lambda _: None)
+          set_state=lambda _: None,
+          basic_symbol_names=('_',),
+          composite_symbol_names=())
 
     self.assertEqual(1, self.evaluate(test_fn(constant_op.constant(True))))
     self.assertEqual(-1, self.evaluate(test_fn(constant_op.constant(False))))
@@ -445,7 +522,9 @@ class IfStmtTest(test.TestCase):
           body=lambda: (constant_op.constant(1), constant_op.constant(2)),
           orelse=lambda: (constant_op.constant(-1), constant_op.constant(-2)),
           get_state=lambda: (),
-          set_state=lambda _: None)
+          set_state=lambda _: None,
+          basic_symbol_names=('_',),
+          composite_symbol_names=())
 
     self.assertEqual((1, 2), self.evaluate(test_fn(constant_op.constant(True))))
     self.assertEqual((-1, -2),
@@ -459,7 +538,9 @@ class IfStmtTest(test.TestCase):
           body=lambda: 1,
           orelse=lambda: -1,
           get_state=lambda: (),
-          set_state=lambda _: None)
+          set_state=lambda _: None,
+          basic_symbol_names=('_',),
+          composite_symbol_names=())
 
     self.assertEqual(1, test_fn(True))
     self.assertEqual(-1, test_fn(False))
@@ -472,7 +553,9 @@ class IfStmtTest(test.TestCase):
           body=lambda: (1, 2),
           orelse=lambda: (-1, -2),
           get_state=lambda: (),
-          set_state=lambda _: None)
+          set_state=lambda _: None,
+          basic_symbol_names=('_',),
+          composite_symbol_names=())
 
     self.assertEqual((1, 2), test_fn(True))
     self.assertEqual((-1, -2), test_fn(False))

From 01353dbe4b11737d47a0b25db0aedd725dfac241 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Tue, 3 Dec 2019 15:53:46 -0800
Subject: [PATCH 250/279] Modify custom_training_loop_test to modify dataset
 rather than simply pass through.

PiperOrigin-RevId: 283640270
Change-Id: Ib58de35c359452cdc290fcc30193f01f8b116b37
---
 .../distribute/custom_training_loop_test.py      | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/tensorflow/python/distribute/custom_training_loop_test.py b/tensorflow/python/distribute/custom_training_loop_test.py
index 1db9bff21f0..55c2ae6a1ca 100644
--- a/tensorflow/python/distribute/custom_training_loop_test.py
+++ b/tensorflow/python/distribute/custom_training_loop_test.py
@@ -43,7 +43,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase):
     dataset = self._get_dataset()
 
     def train_step(data):
-      return data
+      return math_ops.square(data)
 
     dist_dataset = distribution.experimental_distribute_dataset(dataset)
     results = []
@@ -63,7 +63,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase):
 
     @def_function.function
     def train_step(data):
-      return data
+      return math_ops.square(data)
 
     dist_dataset = distribution.experimental_distribute_dataset(dataset)
     results = []
@@ -82,7 +82,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase):
     dataset = self._get_dataset()
 
     def train_step(data):
-      return data
+      return math_ops.square(data)
 
     @def_function.function
     def f_train_step(input_data):
@@ -105,9 +105,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase):
     dataset = self._get_dataset()
 
     def train_step(data):
-      if math_ops.reduce_sum(data) < 0:
-        return -data
-      return data
+      return math_ops.square(data)
 
     @def_function.function
     def f_train_step(input_data):
@@ -171,7 +169,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase):
   def testIterationInsideFunction(self, distribution):
 
     def step_fn(data):
-      return data
+      return math_ops.square(data)
 
     @def_function.function
     def train(dataset):
@@ -199,7 +197,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase):
   def testIterationOutsideFunction(self, distribution):
 
     def train_step(data):
-      return data
+      return math_ops.square(data)
 
     @def_function.function
     def f_train_step(input_data):
@@ -226,7 +224,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase):
         map(lambda x: math_ops.cast(x, dtypes.int32)).batch(2)
 
   def _validate_outputs(self, actual_results):
-    expected_results = [[i, i+1] for i in range(0, 10, 2)]
+    expected_results = [[i**2, (i+1)**2] for i in range(0, 10, 2)]
     self.assertEqual(len(expected_results), len(actual_results))
 
     for i, expected_result in enumerate(expected_results):

From b02e0bdb0ec3be6571c84f81691ab956e045b005 Mon Sep 17 00:00:00 2001
From: Rick Chao 
Date: Tue, 3 Dec 2019 15:57:00 -0800
Subject: [PATCH 251/279] Migrate tensorflow/python/keras:data_utils_test to
 PY3.

PiperOrigin-RevId: 283640942
Change-Id: Ibdfb4827127b8ca9d7ac815e720da7d45a80f726
---
 tensorflow/python/keras/BUILD                    | 1 +
 tensorflow/python/keras/utils/data_utils_test.py | 1 +
 2 files changed, 2 insertions(+)

diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 88b6165c2a1..069586d3d59 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -1187,6 +1187,7 @@ tf_py_test(
         "//third_party/py/numpy",
         "//tensorflow/python:client_testlib",
     ],
+    python_version = "PY3",
     shard_count = 6,
     tags = [
         "noasan",  # times out
diff --git a/tensorflow/python/keras/utils/data_utils_test.py b/tensorflow/python/keras/utils/data_utils_test.py
index 0d3854890c5..e10d8064401 100644
--- a/tensorflow/python/keras/utils/data_utils_test.py
+++ b/tensorflow/python/keras/utils/data_utils_test.py
@@ -241,6 +241,7 @@ class TestEnqueuers(test.TestCase):
     # One epoch is completed so enqueuer will switch the Sequence
 
     acc = []
+    self.skipTest('b/145555807 flakily timing out.')
     for _ in range(100):
       acc.append(next(gen_output2)[0, 0, 0, 0])
     self.assertEqual(acc[-1], 99 * 15)

From bd08a224302ca83f028f5527e53c597cfce1a6a7 Mon Sep 17 00:00:00 2001
From: Yu-Cheng Ling 
Date: Tue, 3 Dec 2019 15:57:11 -0800
Subject: [PATCH 252/279] tflite_convert: Always propagate the
 experimental_new_converter flag.

This is a non-functional change for now because the new converter is disblaed
by default. The change is for smoothening the flow when we enable it in the future.

PiperOrigin-RevId: 283640972
Change-Id: I9dd1c9b60cf631f8bb1b3913627963b195aa7934
---
 tensorflow/lite/python/tflite_convert.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tensorflow/lite/python/tflite_convert.py b/tensorflow/lite/python/tflite_convert.py
index 02a00ea79b6..59e43be807a 100644
--- a/tensorflow/lite/python/tflite_convert.py
+++ b/tensorflow/lite/python/tflite_convert.py
@@ -231,8 +231,9 @@ def _convert_tf2_model(flags):
     model = keras.models.load_model(flags.keras_model_file)
     converter = lite.TFLiteConverterV2.from_keras_model(model)
 
-  if flags.experimental_new_converter:
-    converter.experimental_new_converter = True
+  # TODO(b/145312675): Enable the new converter by default. It requires to
+  # add a new command line argument like `experimental_legacy_converter`.
+  converter.experimental_new_converter = flags.experimental_new_converter
 
   # Convert the model.
   tflite_model = converter.convert()

From e3a57c68ed0f7ea0ba991893756fcbfbf615a65e Mon Sep 17 00:00:00 2001
From: Gaurav Jain 
Date: Tue, 3 Dec 2019 16:03:59 -0800
Subject: [PATCH 253/279] Speed up 0/1 detection in tf.random_uniform

PiperOrigin-RevId: 283642578
Change-Id: If167158639d3dec2778a0a82f16790b8b7c34236
---
 tensorflow/python/ops/random_ops.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py
index 94ada0515cf..f9208cca551 100644
--- a/tensorflow/python/ops/random_ops.py
+++ b/tensorflow/python/ops/random_ops.py
@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function
 
 import numpy as np
-import six
 
 from tensorflow.python.compat import compat
 from tensorflow.python.eager import context
@@ -266,11 +265,10 @@ def random_uniform(shape,
     shape = tensor_util.shape_tensor(shape)
     # TODO(b/143079601): Remove this once the compatible window is passed.
     if compat.forward_compatible(2019, 12, 3):
-      # In case of [0,1) floating results, minval and maxval is unused.
-      minval_is_zero = isinstance(minval, six.integer_types +
-                                  (float,)) and minval == 0
-      maxval_is_one = isinstance(maxval, six.integer_types +
-                                 (float,)) and maxval == 1
+      # In case of [0,1) floating results, minval and maxval is unused. We do an
+      # `is` comparison here since this is cheaper than isinstance or  __eq__.
+      minval_is_zero = minval is 0  # pylint: disable=literal-comparison
+      maxval_is_one = maxval is 1  # pylint: disable=literal-comparison
       if not minval_is_zero or not maxval_is_one or dtype.is_integer:
         minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
         maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")

From 5b3a7fdcb9342fe2be0e6dc7359e5cdca8118c37 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Tue, 3 Dec 2019 16:05:46 -0800
Subject: [PATCH 254/279] Add a pass to legalize operations before lowering to
 SPIR-V.

Not all StandardOps can be lowered to SPIR-V. For example, subview op
implementation requires use of pointer bitcasts which is not valid
according to SPIR-V spec (or at least is ambiguous about it). Such ops
need to be removed/transformed before lowering to SPIR-V. The
SPIRVLegalizationPass is added a place where such legalizations can be
added. Current implementation folds the subview ops with load/stores
so that the lowering itself does not have to convert a subview op.

PiperOrigin-RevId: 283642981
Change-Id: I9bcb40111871a98ed87e6f593ab89bbdc6b616c1
---
 third_party/mlir/BUILD                        |   1 +
 .../StandardToSPIRV/ConvertStandardToSPIRV.h  |   9 +-
 .../ConvertStandardToSPIRVPass.h              |   5 +
 .../include/mlir/Dialect/StandardOps/Ops.td   |   5 +
 .../Conversion/StandardToSPIRV/CMakeLists.txt |   1 +
 .../ConvertStandardToSPIRV.cpp                |   2 +-
 .../LegalizeStandardForSPIRV.cpp              | 192 ++++++++++++++++++
 .../mlir/lib/Dialect/StandardOps/Ops.cpp      |  36 ++++
 8 files changed, 249 insertions(+), 2 deletions(-)
 create mode 100644 third_party/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp

diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD
index 73fbf86cde7..17e7c5b58c9 100644
--- a/third_party/mlir/BUILD
+++ b/third_party/mlir/BUILD
@@ -1137,6 +1137,7 @@ cc_library(
     srcs = [
         "lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp",
         "lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp",
+        "lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp",
     ],
     hdrs = [
         "include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h",
diff --git a/third_party/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h b/third_party/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h
index 69db8171ed0..4caa6d9de77 100644
--- a/third_party/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h
+++ b/third_party/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h
@@ -26,12 +26,19 @@
 
 namespace mlir {
 class SPIRVTypeConverter;
+
 /// Appends to a pattern list additional patterns for translating StandardOps to
-/// SPIR-V ops.
+/// SPIR-V ops. Also adds the patterns legalize ops not directly translated to
+/// SPIR-V dialect.
 void populateStandardToSPIRVPatterns(MLIRContext *context,
                                      SPIRVTypeConverter &typeConverter,
                                      OwningRewritePatternList &patterns);
 
+/// Appends to a pattern list patterns to legalize ops that are not directly
+/// lowered to SPIR-V.
+void populateStdLegalizationPatternsForSPIRVLowering(
+    MLIRContext *context, OwningRewritePatternList &patterns);
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRV_H
diff --git a/third_party/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h b/third_party/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h
index 1bf497708de..e8a71feb8b2 100644
--- a/third_party/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h
+++ b/third_party/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h
@@ -25,8 +25,13 @@
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+
 /// Pass to convert StandardOps to SPIR-V ops.
 std::unique_ptr> createConvertStandardToSPIRVPass();
+
+/// Pass to legalize ops that are not directly lowered to SPIR-V.
+std::unique_ptr createLegalizeStdOpsForSPIRVLoweringPass();
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRVPASS_H
diff --git a/third_party/mlir/include/mlir/Dialect/StandardOps/Ops.td b/third_party/mlir/include/mlir/Dialect/StandardOps/Ops.td
index 7617d3cb247..51c7bfbccdc 100644
--- a/third_party/mlir/include/mlir/Dialect/StandardOps/Ops.td
+++ b/third_party/mlir/include/mlir/Dialect/StandardOps/Ops.td
@@ -1397,6 +1397,11 @@ def SubViewOp : Std_Op<"subview", [AttrSizedOperandSegments, NoSideEffect]> {
     /// Returns the dynamic sizes for this subview operation if specified.
     operand_range getDynamicSizes() { return sizes(); }
 
+    /// Returns in `staticStrides` the static value of the stride
+    /// operands. Returns failure() if the static value of the stride
+    /// operands could not be retrieved.
+    LogicalResult getStaticStrides(SmallVectorImpl &staticStrides);
+
     // Auxiliary range data structure and helper function that unpacks the
     // offset, size and stride operands of the SubViewOp into a list of triples.
     // Such a list of triple is sometimes more convenient to manipulate.
diff --git a/third_party/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt b/third_party/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt
index 351216216f1..fcced23a95e 100644
--- a/third_party/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt
+++ b/third_party/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt
@@ -5,6 +5,7 @@ add_public_tablegen_target(MLIRStandardToSPIRVIncGen)
 add_llvm_library(MLIRStandardToSPIRVTransforms
   ConvertStandardToSPIRV.cpp
   ConvertStandardToSPIRVPass.cpp
+  LegalizeStandardForSPIRV.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
diff --git a/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index ee2dfedc15b..c2ca4c94878 100644
--- a/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -262,8 +262,8 @@ namespace mlir {
 void populateStandardToSPIRVPatterns(MLIRContext *context,
                                      SPIRVTypeConverter &typeConverter,
                                      OwningRewritePatternList &patterns) {
+  // Add patterns that lower operations into SPIR-V dialect.
   populateWithGenerated(context, &patterns);
-  // Add the return op conversion.
   patterns
       .insert,
diff --git a/third_party/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/third_party/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
new file mode 100644
index 00000000000..1e8afbf43e1
--- /dev/null
+++ b/third_party/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
@@ -0,0 +1,192 @@
+//===- LegalizeStandardForSPIRV.cpp - Legalize ops for SPIR-V lowering ----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This transformation pass legalizes operations before the conversion to SPIR-V
+// dialect to handle ops that cannot be lowered directly.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// Merges subview operation with load operation.
+class LoadOpOfSubViewFolder final : public OpRewritePattern {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(LoadOp loadOp,
+                                     PatternRewriter &rewriter) const override;
+};
+
+/// Merges subview operation with store operation.
+class StoreOpOfSubViewFolder final : public OpRewritePattern {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(StoreOp storeOp,
+                                     PatternRewriter &rewriter) const override;
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Utility functions for op legalization.
+//===----------------------------------------------------------------------===//
+
+/// Given the 'indices' of an load/store operation where the memref is a result
+/// of a subview op, returns the indices w.r.t to the source memref of the
+/// subview op. For example
+///
+/// %0 = ... : memref<12x42xf32>
+/// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
+///          memref<4x4xf32, offset=?, strides=[?, ?]>
+/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
+///
+/// could be folded into
+///
+/// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
+///          memref<12x42xf32>
+static LogicalResult
+resolveSourceIndices(Location loc, PatternRewriter &rewriter,
+                     SubViewOp subViewOp, ArrayRef indices,
+                     SmallVectorImpl &sourceIndices) {
+  // TODO: Aborting when the offsets are static. There might be a way to fold
+  // the subview op with load even if the offsets have been canonicalized
+  // away.
+  if (subViewOp.getNumOffsets() == 0)
+    return failure();
+
+  SmallVector opOffsets = llvm::to_vector<2>(subViewOp.offsets());
+  SmallVector opStrides;
+  if (subViewOp.getNumStrides()) {
+    // If the strides are dynamic, get the stride operands.
+    opStrides = llvm::to_vector<2>(subViewOp.strides());
+  } else {
+    // When static, the stride operands can be retrieved by taking the strides
+    // of the result of the subview op, and dividing the strides of the base
+    // memref.
+    SmallVector staticStrides;
+    if (failed(subViewOp.getStaticStrides(staticStrides))) {
+      return failure();
+    }
+    opStrides.reserve(opOffsets.size());
+    for (auto stride : staticStrides) {
+      auto constValAttr = rewriter.getIntegerAttr(
+          IndexType::get(rewriter.getContext()), stride);
+      opStrides.emplace_back(rewriter.create(loc, constValAttr));
+    }
+  }
+  assert(opOffsets.size() == opStrides.size());
+
+  // New indices for the load are the current indices * subview_stride +
+  // subview_offset.
+  assert(indices.size() == opStrides.size());
+  sourceIndices.resize(indices.size());
+  for (auto index : enumerate(indices)) {
+    auto offset = opOffsets[index.index()];
+    auto stride = opStrides[index.index()];
+    auto mul = rewriter.create(loc, index.value(), stride);
+    sourceIndices[index.index()] =
+        rewriter.create(loc, offset, mul).getResult();
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Folding SubViewOp and LoadOp.
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult
+LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp,
+                                       PatternRewriter &rewriter) const {
+  auto subViewOp =
+      dyn_cast_or_null(loadOp.memref()->getDefiningOp());
+  if (!subViewOp) {
+    return matchFailure();
+  }
+  SmallVector sourceIndices,
+      indices = llvm::to_vector<4>(loadOp.indices());
+  if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp, indices,
+                                  sourceIndices)))
+    return matchFailure();
+
+  rewriter.replaceOpWithNewOp(loadOp, subViewOp.source(),
+                                      sourceIndices);
+  return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// Folding SubViewOp and StoreOp.
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult
+StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp,
+                                        PatternRewriter &rewriter) const {
+  auto subViewOp =
+      dyn_cast_or_null(storeOp.memref()->getDefiningOp());
+  if (!subViewOp) {
+    return matchFailure();
+  }
+  SmallVector sourceIndices,
+      indices = llvm::to_vector<4>(storeOp.indices());
+  if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
+                                  indices, sourceIndices)))
+    return matchFailure();
+
+  rewriter.replaceOpWithNewOp(storeOp, storeOp.value(),
+                                       subViewOp.source(), sourceIndices);
+  return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// Hook for adding patterns.
+//===----------------------------------------------------------------------===//
+
+void mlir::populateStdLegalizationPatternsForSPIRVLowering(
+    MLIRContext *context, OwningRewritePatternList &patterns) {
+  patterns.insert(context);
+}
+
+//===----------------------------------------------------------------------===//
+// Pass for testing just the legalization patterns.
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct SPIRVLegalization final : public OperationPass {
+  void runOnOperation() override;
+};
+} // namespace
+
+void SPIRVLegalization::runOnOperation() {
+  OwningRewritePatternList patterns;
+  auto *context = &getContext();
+  populateStdLegalizationPatternsForSPIRVLowering(context, patterns);
+  applyPatternsGreedily(getOperation()->getRegions(), patterns);
+}
+
+std::unique_ptr mlir::createLegalizeStdOpsForSPIRVLoweringPass() {
+  return std::make_unique();
+}
+
+static PassRegistration
+    pass("legalize-std-for-spirv", "Legalize standard ops for SPIR-V lowering");
diff --git a/third_party/mlir/lib/Dialect/StandardOps/Ops.cpp b/third_party/mlir/lib/Dialect/StandardOps/Ops.cpp
index 361135c4e29..9f6510d0f17 100644
--- a/third_party/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/third_party/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -2761,6 +2761,42 @@ SmallVector SubViewOp::getRanges() {
   return res;
 }
 
+LogicalResult
+SubViewOp::getStaticStrides(SmallVectorImpl &staticStrides) {
+  // If the strides are dynamic return failure.
+  if (getNumStrides())
+    return failure();
+
+  // When static, the stride operands can be retrieved by taking the strides of
+  // the result of the subview op, and dividing the strides of the base memref.
+  int64_t resultOffset, baseOffset;
+  SmallVector resultStrides, baseStrides;
+  if (failed(
+          getStridesAndOffset(getBaseMemRefType(), baseStrides, baseOffset)) ||
+      llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset()) ||
+      failed(getStridesAndOffset(getType(), resultStrides, resultOffset)))
+    return failure();
+
+  assert(static_cast(resultStrides.size()) == getType().getRank() &&
+         baseStrides.size() == resultStrides.size() &&
+         "base and result memrefs must have the same rank");
+  assert(!llvm::is_contained(resultStrides,
+                             MemRefType::getDynamicStrideOrOffset()) &&
+         "strides of subview op must be static, when there are no dynamic "
+         "strides specified");
+  staticStrides.resize(getType().getRank());
+  for (auto resultStride : enumerate(resultStrides)) {
+    auto baseStride = baseStrides[resultStride.index()];
+    // The result stride is expected to be a multiple of the base stride. Abort
+    // if that is not the case.
+    if (resultStride.value() < baseStride ||
+        resultStride.value() % baseStride != 0)
+      return failure();
+    staticStrides[resultStride.index()] = resultStride.value() / baseStride;
+  }
+  return success();
+}
+
 static bool hasConstantOffsetSizesAndStrides(MemRefType memrefType) {
   if (memrefType.getNumDynamicDims() > 0)
     return false;

From 9c91ebe6ec5311c4f2b572a2ee8d896924f350ea Mon Sep 17 00:00:00 2001
From: Bixia Zheng 
Date: Tue, 3 Dec 2019 16:19:06 -0800
Subject: [PATCH 255/279] [TF:MLIR] Fix windows build.

Remove the extra ModulePass for the Standardpipeline Options.

PiperOrigin-RevId: 283645653
Change-Id: Id4fb7efcab585701af005b96ba05fd59d37b3546
---
 .../compiler/mlir/tensorflow/transforms/bridge.cc   |  2 +-
 .../compiler/mlir/tensorflow/transforms/optimize.cc |  4 ++--
 .../compiler/mlir/tensorflow/transforms/passes.h    | 13 +++++--------
 3 files changed, 8 insertions(+), 11 deletions(-)

diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc
index d9bae902382..a7f45c41f15 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc
@@ -76,7 +76,7 @@ tensorflow::Status RunBridgeWithStandardPipeline(ModuleOp module,
   if (enable_logging)
     bridge.addInstrumentation(std::make_unique());
 
-  StandardPipeline::Options pipeline_options;
+  StandardPipelineOptions pipeline_options;
   pipeline_options.enable_inliner.setValue(enable_inliner);
   CreateTFStandardPipeline(bridge, pipeline_options);
   mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc
index 9dd5accc81c..b0420663bde 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc
@@ -46,7 +46,7 @@ struct TFOptimizePass : public FunctionPass {
 
 // NOLINTNEXTLINE - MLIR contract is pass by mutable reference.
 void CreateTFStandardPipeline(OpPassManager &pm,
-                              const StandardPipeline::Options &options) {
+                              const StandardPipelineOptions &options) {
   OpPassManager &func_pm = pm.nest();
 
   // First operates on the executor dialect:
@@ -75,7 +75,7 @@ std::unique_ptr> CreateTFOptimizePass() {
 static PassRegistration pass("tf-optimize", "Optimizes TF.");
 
 // Registers a pipeline builder function for the default canonicalize/optimizer.
-static mlir::PassPipelineRegistration pipeline(
+static mlir::PassPipelineRegistration pipeline(
     "tf-standard-pipeline",
     "Run all the passes involved in transforming/optimizing the graph after "
     "importing into MLIR, without any target specialization.",
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
index 49458c4eac6..30ee91f4aea 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
@@ -46,20 +46,17 @@ std::unique_ptr> CreateTFShapeInferencePass();
 // Optimizes Tensorflow graph.
 std::unique_ptr> CreateTFOptimizePass();
 
-class StandardPipeline : public ModulePass {
- public:
-  struct Options : public PassOptions {
-    Option enable_inliner{*this, "enable-inliner",
-                                llvm::cl::desc("Enable inliner."),
-                                llvm::cl::init(false)};
-  };
+struct StandardPipelineOptions : public PassOptions {
+  Option enable_inliner{*this, "enable-inliner",
+                              llvm::cl::desc("Enable inliner."),
+                              llvm::cl::init(false)};
 };
 
 // Propagates the pass manager with the passes involved in transforming or
 // optimizing an MLIR graph without any target specialization.
 // NOLINTNEXTLINE - MLIR contract is pass by mutable reference.
 void CreateTFStandardPipeline(OpPassManager& pm,
-                              const StandardPipeline::Options& options);
+                              const StandardPipelineOptions& options);
 }  // namespace TF
 
 namespace TFControlFlow {

From b287334222091f2df65028b5152f239bf6ada114 Mon Sep 17 00:00:00 2001
From: Shanqing Cai 
Date: Tue, 3 Dec 2019 16:23:17 -0800
Subject: [PATCH 256/279] [tfdbg] Make sure that file readers are closed
 properly after tests

for best practice of testing

To this end,
- The DebugEventsReader class is turned into a context manager, the
  __exit__() call to which closes the files
- Let more unit tests inherit a common DumpingCallbackTestBase base class,
  to reduce the amount of duplicate logic slightly.

Also in this change:
- Use OrderedDict in lieu of dict in some of the return values of
  the base test class.

PiperOrigin-RevId: 283646428
Change-Id: Id624a53fa4d755ec94ad1110ac6aabcb62144176
---
 tensorflow/python/debug/BUILD                 |   2 +
 .../python/debug/lib/debug_events_reader.py   |  11 +
 .../debug/lib/debug_events_writer_test.py     | 137 +++----
 .../python/debug/lib/debug_v2_ops_test.py     | 166 ++++----
 .../python/debug/lib/dumping_callback_test.py | 356 +++++++++---------
 .../debug/lib/dumping_callback_test_lib.py    | 314 +++++++--------
 6 files changed, 487 insertions(+), 499 deletions(-)

diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 97fe48ee165..2bc35ef52af 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -654,6 +654,7 @@ py_test(
     deps = [
         ":debug_events_reader",
         ":debug_events_writer",
+        ":dumping_callback_test_lib",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_test_lib",
@@ -767,6 +768,7 @@ cuda_py_test(
     additional_deps = [
         ":debug_events_reader",
         ":debug_events_writer",
+        ":dumping_callback_test_lib",
         "//third_party/py/numpy",
         "//tensorflow/python:debug_ops_gen",
         "//tensorflow/python:framework_test_lib",
diff --git a/tensorflow/python/debug/lib/debug_events_reader.py b/tensorflow/python/debug/lib/debug_events_reader.py
index 2a9f331439b..a20cc175ebb 100644
--- a/tensorflow/python/debug/lib/debug_events_reader.py
+++ b/tensorflow/python/debug/lib/debug_events_reader.py
@@ -55,6 +55,13 @@ class DebugEventsReader(object):
     self._readers = dict()  # A map from file path to reader.
     self._readers_lock = threading.Lock()
 
+  def __enter__(self):
+    return self
+
+  def __exit__(self, exception_type, exception_value, traceback):
+    del exception_type, exception_value, traceback  # Unused
+    self.close()
+
   def _generic_iterator(self, file_path):
     """A helper method that makes an iterator given a debug-events file path."""
     # The following code uses the double-checked locking pattern to optimize
@@ -93,3 +100,7 @@ class DebugEventsReader(object):
 
   def graph_execution_traces_iterator(self):
     return self._generic_iterator(self._graph_execution_traces_path)
+
+  def close(self):
+    for reader in self._readers.values():
+      reader.Close()
diff --git a/tensorflow/python/debug/lib/debug_events_writer_test.py b/tensorflow/python/debug/lib/debug_events_writer_test.py
index 5c85ec6dcdc..86e7fd26e1a 100644
--- a/tensorflow/python/debug/lib/debug_events_writer_test.py
+++ b/tensorflow/python/debug/lib/debug_events_writer_test.py
@@ -20,31 +20,19 @@ from __future__ import print_function
 
 import glob
 import os
-import tempfile
 import threading
 
 from tensorflow.core.protobuf import debug_event_pb2
 from tensorflow.python.debug.lib import debug_events_reader
 from tensorflow.python.debug.lib import debug_events_writer
+from tensorflow.python.debug.lib import dumping_callback_test_lib
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import test_util
-from tensorflow.python.lib.io import file_io
 from tensorflow.python.platform import googletest
 
 
-class PywrapeventsWriterTest(test_util.TensorFlowTestCase):
-
-  def setUp(self):
-    super(PywrapeventsWriterTest, self).setUp()
-    self.dump_root = tempfile.mkdtemp()
-
-  def tearDown(self):
-    if os.path.isdir(self.dump_root):
-      file_io.delete_recursively(self.dump_root)
-    super(PywrapeventsWriterTest, self).tearDown()
+class DebugEventsWriterTest(dumping_callback_test_lib.DumpingCallbackTestBase):
 
   def testMultiThreadedConstructorCallWorks(self):
-
     def InitWriter():
       debug_events_writer.DebugEventsWriter(self.dump_root)
 
@@ -68,14 +56,7 @@ class PywrapeventsWriterTest(test_util.TensorFlowTestCase):
     self.assertEqual(len(stack_frames_paths), 1)
     graphs_paths = glob.glob(os.path.join(self.dump_root, "*.graphs"))
     self.assertEqual(len(graphs_paths), 1)
-
-    # Verify the content of the metadata file.
-    reader = debug_events_reader.DebugEventsReader(self.dump_root)
-    metadata_iter = reader.metadata_iterator()
-    debug_event = next(metadata_iter)
-    self.assertTrue(debug_event.debug_metadata.tensorflow_version)
-    self.assertTrue(
-        debug_event.debug_metadata.file_version.startswith("debug.Event:"))
+    self._readAndCheckMetadataFile()
 
   def testWriteSourceFilesAndStackFrames(self):
     writer = debug_events_writer.DebugEventsWriter(self.dump_root)
@@ -94,21 +75,21 @@ class PywrapeventsWriterTest(test_util.TensorFlowTestCase):
 
     writer.FlushNonExecutionFiles()
 
-    reader = debug_events_reader.DebugEventsReader(self.dump_root)
-    actuals = list(reader.source_files_iterator())
-    self.assertLen(actuals, num_protos)
-    for i in range(num_protos):
-      self.assertEqual(actuals[i].source_file.file_path,
-                       "/home/tf2user/main.py")
-      self.assertEqual(actuals[i].source_file.host_name, "machine.cluster")
-      self.assertEqual(actuals[i].source_file.lines, ["print(%d)" % i])
+    with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
+      actuals = list(reader.source_files_iterator())
+      self.assertLen(actuals, num_protos)
+      for i in range(num_protos):
+        self.assertEqual(actuals[i].source_file.file_path,
+                         "/home/tf2user/main.py")
+        self.assertEqual(actuals[i].source_file.host_name, "machine.cluster")
+        self.assertEqual(actuals[i].source_file.lines, ["print(%d)" % i])
 
-    actuals = list(reader.stack_frames_iterator())
-    self.assertLen(actuals, num_protos)
-    for i in range(num_protos):
-      self.assertEqual(actuals[i].stack_frame_with_id.id, "stack_%d" % i)
-      self.assertEqual(actuals[i].stack_frame_with_id.file_line_col.file_index,
-                       i * 10)
+      actuals = list(reader.stack_frames_iterator())
+      self.assertLen(actuals, num_protos)
+      for i in range(num_protos):
+        self.assertEqual(actuals[i].stack_frame_with_id.id, "stack_%d" % i)
+        self.assertEqual(
+            actuals[i].stack_frame_with_id.file_line_col.file_index, i * 10)
 
   def testWriteGraphOpCreationAndDebuggedGraphs(self):
     writer = debug_events_writer.DebugEventsWriter(self.dump_root)
@@ -188,15 +169,15 @@ class PywrapeventsWriterTest(test_util.TensorFlowTestCase):
     for thread in threads:
       thread.join()
 
-    reader = debug_events_reader.DebugEventsReader(self.dump_root)
     # Verify the content of the .source_files file.
-    source_files_iter = reader.source_files_iterator()
-    actuals = list(source_files_iter)
-    file_paths = sorted([actual.source_file.file_path for actual in actuals])
-    self.assertEqual(file_paths, [
-        "/home/tf2user/file_0.py", "/home/tf2user/file_1.py",
-        "/home/tf2user/file_2.py"
-    ])
+    with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
+      source_files_iter = reader.source_files_iterator()
+      actuals = list(source_files_iter)
+      file_paths = sorted([actual.source_file.file_path for actual in actuals])
+      self.assertEqual(file_paths, [
+          "/home/tf2user/file_0.py", "/home/tf2user/file_1.py",
+          "/home/tf2user/file_2.py"
+      ])
 
     # Verify the content of the .stack_frames file.
     actuals = list(reader.stack_frames_iterator())
@@ -219,18 +200,16 @@ class PywrapeventsWriterTest(test_util.TensorFlowTestCase):
       execution.op_type = "OpType%d" % i
       writer.WriteExecution(execution)
 
-    reader = debug_events_reader.DebugEventsReader(self.dump_root)
-    actuals = list(reader.execution_iterator())
     # Before FlushExecutionFiles() is called. No data should have been written
     # to the file.
-    self.assertEqual(len(actuals), 0)
+    executed_op_types, _, _, _, _ = self._readAndCheckExecutionFile()
+    self.assertFalse(executed_op_types)
 
     writer.FlushExecutionFiles()
-    actuals = list(reader.execution_iterator())
-    self.assertLen(actuals, debug_events_writer.DEFAULT_CIRCULAR_BUFFER_SIZE)
-    for i in range(debug_events_writer.DEFAULT_CIRCULAR_BUFFER_SIZE):
+    executed_op_types, _, _, _, _ = self._readAndCheckExecutionFile()
+    for i, executed_op_type in enumerate(executed_op_types):
       self.assertEqual(
-          actuals[i].execution.op_type,
+          executed_op_type,
           "OpType%d" % (i + debug_events_writer.DEFAULT_CIRCULAR_BUFFER_SIZE))
 
   def testWriteExecutionEventsWithoutCircularBufferBehavior(self):
@@ -243,11 +222,10 @@ class PywrapeventsWriterTest(test_util.TensorFlowTestCase):
       writer.WriteExecution(execution)
     writer.FlushExecutionFiles()
 
-    reader = debug_events_reader.DebugEventsReader(self.dump_root)
-    actuals = list(reader.execution_iterator())
-    self.assertLen(actuals, num_execution_events)
-    for i in range(num_execution_events):
-      self.assertEqual(actuals[i].execution.op_type, "OpType%d" % i)
+    executed_op_types, _, _, _, _ = self._readAndCheckExecutionFile()
+    self.assertLen(executed_op_types, num_execution_events)
+    for i, executed_op_type in enumerate(executed_op_types):
+      self.assertEqual(executed_op_type, "OpType%d" % i)
 
   def testWriteGraphExecutionTraceEventsWithCircularBuffer(self):
     writer = debug_events_writer.DebugEventsWriter(self.dump_root)
@@ -257,19 +235,19 @@ class PywrapeventsWriterTest(test_util.TensorFlowTestCase):
       trace.op_name = "Op%d" % i
       writer.WriteGraphExecutionTrace(trace)
 
-    reader = debug_events_reader.DebugEventsReader(self.dump_root)
-    actuals = list(reader.graph_execution_traces_iterator())
-    # Before FlushExecutionFiles() is called. No data should have been written
-    # to the file.
-    self.assertEqual(len(actuals), 0)
+    with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
+      actuals = list(reader.graph_execution_traces_iterator())
+      # Before FlushExecutionFiles() is called. No data should have been written
+      # to the file.
+      self.assertEqual(len(actuals), 0)
 
-    writer.FlushExecutionFiles()
-    actuals = list(reader.graph_execution_traces_iterator())
-    self.assertLen(actuals, debug_events_writer.DEFAULT_CIRCULAR_BUFFER_SIZE)
-    for i in range(debug_events_writer.DEFAULT_CIRCULAR_BUFFER_SIZE):
-      self.assertEqual(
-          actuals[i].graph_execution_trace.op_name,
-          "Op%d" % (i + debug_events_writer.DEFAULT_CIRCULAR_BUFFER_SIZE))
+      writer.FlushExecutionFiles()
+      actuals = list(reader.graph_execution_traces_iterator())
+      self.assertLen(actuals, debug_events_writer.DEFAULT_CIRCULAR_BUFFER_SIZE)
+      for i in range(debug_events_writer.DEFAULT_CIRCULAR_BUFFER_SIZE):
+        self.assertEqual(
+            actuals[i].graph_execution_trace.op_name,
+            "Op%d" % (i + debug_events_writer.DEFAULT_CIRCULAR_BUFFER_SIZE))
 
   def testWriteGraphExecutionTraceEventsWithoutCircularBufferBehavior(self):
     # A circular buffer size of 0 abolishes the circular buffer behavior.
@@ -281,8 +259,8 @@ class PywrapeventsWriterTest(test_util.TensorFlowTestCase):
       writer.WriteGraphExecutionTrace(trace)
     writer.FlushExecutionFiles()
 
-    reader = debug_events_reader.DebugEventsReader(self.dump_root)
-    actuals = list(reader.graph_execution_traces_iterator())
+    with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
+      actuals = list(reader.graph_execution_traces_iterator())
     self.assertLen(actuals, num_execution_events)
     for i in range(num_execution_events):
       self.assertEqual(actuals[i].graph_execution_trace.op_name, "Op%d" % i)
@@ -324,18 +302,17 @@ class PywrapeventsWriterTest(test_util.TensorFlowTestCase):
     writer.FlushExecutionFiles()
 
     # Verify the content of the .execution file.
-    reader = debug_events_reader.DebugEventsReader(self.dump_root)
-    actuals = list(reader.execution_iterator())
-    op_types = sorted([actual.execution.op_type for actual in actuals])
-    self.assertLen(op_types, circular_buffer_size)
-    self.assertLen(op_types, len(set(op_types)))
+    executed_op_types, _, _, _, _ = self._readAndCheckExecutionFile()
+    self.assertLen(executed_op_types, circular_buffer_size)
+    self.assertLen(executed_op_types, len(set(executed_op_types)))
 
     # Verify the content of the .execution file.
-    actuals = list(reader.graph_execution_traces_iterator())
-    op_names = sorted(
-        [actual.graph_execution_trace.op_name for actual in actuals])
-    self.assertLen(op_names, circular_buffer_size)
-    self.assertLen(op_names, len(set(op_names)))
+    with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
+      actuals = list(reader.graph_execution_traces_iterator())
+      op_names = sorted(
+          [actual.graph_execution_trace.op_name for actual in actuals])
+      self.assertLen(op_names, circular_buffer_size)
+      self.assertLen(op_names, len(set(op_names)))
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/debug/lib/debug_v2_ops_test.py b/tensorflow/python/debug/lib/debug_v2_ops_test.py
index f4a8b46352c..08b0ec17316 100644
--- a/tensorflow/python/debug/lib/debug_v2_ops_test.py
+++ b/tensorflow/python/debug/lib/debug_v2_ops_test.py
@@ -19,30 +19,28 @@ from __future__ import division
 from __future__ import print_function
 
 import os
-import tempfile
 
 import numpy as np
 
 from tensorflow.core.protobuf import debug_event_pb2
 from tensorflow.python.debug.lib import debug_events_reader
 from tensorflow.python.debug.lib import debug_events_writer
+from tensorflow.python.debug.lib import dumping_callback_test_lib
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.framework import test_util
-from tensorflow.python.lib.io import file_io
 from tensorflow.python.ops import gen_debug_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import googletest
 
 
-class DebugIdentityV2OpTest(test_util.TensorFlowTestCase):
+class DebugIdentityV2OpTest(dumping_callback_test_lib.DumpingCallbackTestBase):
 
   def setUp(self):
     super(DebugIdentityV2OpTest, self).setUp()
-    self.dump_root = tempfile.mkdtemp()
     # Testing using a small circular-buffer size.
     self.circular_buffer_size = 4
     self.writer = debug_events_writer.DebugEventsWriter(
@@ -50,8 +48,6 @@ class DebugIdentityV2OpTest(test_util.TensorFlowTestCase):
 
   def tearDown(self):
     self.writer.Close()
-    if os.path.isdir(self.dump_root):
-      file_io.delete_recursively(self.dump_root)
     super(DebugIdentityV2OpTest, self).tearDown()
 
   @test_util.run_in_graph_and_eager_modes
@@ -87,55 +83,55 @@ class DebugIdentityV2OpTest(test_util.TensorFlowTestCase):
       self.assertAllClose(
           write_debug_trace(x), [9.0 + np.sqrt(3.0), 16.0 + 2.0])
 
-    reader = debug_events_reader.DebugEventsReader(self.dump_root)
-    metadata_iter = reader.metadata_iterator()
-    # Check that the .metadata DebugEvents data file has been created, even
-    # before FlushExecutionFiles() is called.
-    debug_event = next(metadata_iter)
-    self.assertGreater(debug_event.wall_time, 0)
-    self.assertTrue(debug_event.debug_metadata.tensorflow_version)
-    self.assertTrue(
-        debug_event.debug_metadata.file_version.startswith("debug.Event:"))
-
-    graph_trace_iter = reader.graph_execution_traces_iterator()
-    # Before FlushExecutionFiles() is called, the .graph_execution_traces file
-    # ought to be empty.
-    with self.assertRaises(StopIteration):
-      next(graph_trace_iter)
-
-    # Flush the circular buffer.
-    self.writer.FlushExecutionFiles()
-    graph_trace_iter = reader.graph_execution_traces_iterator()
-
-    # The circular buffer has a size of 4. So only the data from the
-    # last two iterations should have been written to self.dump_root.
-    for _ in range(2):
-      debug_event = next(graph_trace_iter)
+    with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
+      metadata_iter = reader.metadata_iterator()
+      # Check that the .metadata DebugEvents data file has been created, even
+      # before FlushExecutionFiles() is called.
+      debug_event = next(metadata_iter)
       self.assertGreater(debug_event.wall_time, 0)
-      trace = debug_event.graph_execution_trace
-      self.assertEqual(trace.tfdbg_context_id, "deadbeaf")
-      self.assertEqual(trace.op_name, "Square")
-      self.assertEqual(trace.output_slot, 0)
-      self.assertEqual(trace.tensor_debug_mode,
-                       debug_event_pb2.TensorDebugMode.FULL_TENSOR)
-      tensor_value = tensor_util.MakeNdarray(trace.tensor_proto)
-      self.assertAllClose(tensor_value, [9.0, 16.0])
+      self.assertTrue(debug_event.debug_metadata.tensorflow_version)
+      self.assertTrue(
+          debug_event.debug_metadata.file_version.startswith("debug.Event:"))
 
-      debug_event = next(graph_trace_iter)
-      self.assertGreater(debug_event.wall_time, 0)
-      trace = debug_event.graph_execution_trace
-      self.assertEqual(trace.tfdbg_context_id, "beafdead")
-      self.assertEqual(trace.op_name, "Sqrt")
-      self.assertEqual(trace.output_slot, 0)
-      self.assertEqual(trace.tensor_debug_mode,
-                       debug_event_pb2.TensorDebugMode.FULL_TENSOR)
-      tensor_value = tensor_util.MakeNdarray(trace.tensor_proto)
-      self.assertAllClose(tensor_value, [np.sqrt(3.0), 2.0])
+      graph_trace_iter = reader.graph_execution_traces_iterator()
+      # Before FlushExecutionFiles() is called, the .graph_execution_traces file
+      # ought to be empty.
+      with self.assertRaises(StopIteration):
+        next(graph_trace_iter)
 
-    # Only the graph-execution trace of the last iteration should be written
-    # to self.dump_root.
-    with self.assertRaises(StopIteration):
-      next(graph_trace_iter)
+      # Flush the circular buffer.
+      self.writer.FlushExecutionFiles()
+      graph_trace_iter = reader.graph_execution_traces_iterator()
+
+      # The circular buffer has a size of 4. So only the data from the
+      # last two iterations should have been written to self.dump_root.
+      for _ in range(2):
+        debug_event = next(graph_trace_iter)
+        self.assertGreater(debug_event.wall_time, 0)
+        trace = debug_event.graph_execution_trace
+        self.assertEqual(trace.tfdbg_context_id, "deadbeaf")
+        self.assertEqual(trace.op_name, "Square")
+        self.assertEqual(trace.output_slot, 0)
+        self.assertEqual(trace.tensor_debug_mode,
+                         debug_event_pb2.TensorDebugMode.FULL_TENSOR)
+        tensor_value = tensor_util.MakeNdarray(trace.tensor_proto)
+        self.assertAllClose(tensor_value, [9.0, 16.0])
+
+        debug_event = next(graph_trace_iter)
+        self.assertGreater(debug_event.wall_time, 0)
+        trace = debug_event.graph_execution_trace
+        self.assertEqual(trace.tfdbg_context_id, "beafdead")
+        self.assertEqual(trace.op_name, "Sqrt")
+        self.assertEqual(trace.output_slot, 0)
+        self.assertEqual(trace.tensor_debug_mode,
+                         debug_event_pb2.TensorDebugMode.FULL_TENSOR)
+        tensor_value = tensor_util.MakeNdarray(trace.tensor_proto)
+        self.assertAllClose(tensor_value, [np.sqrt(3.0), 2.0])
+
+      # Only the graph-execution trace of the last iteration should be written
+      # to self.dump_root.
+      with self.assertRaises(StopIteration):
+        next(graph_trace_iter)
 
   @test_util.run_in_graph_and_eager_modes
   def testControlFlow(self):
@@ -162,28 +158,28 @@ class DebugIdentityV2OpTest(test_util.TensorFlowTestCase):
     self.evaluate(collatz(x))
 
     self.writer.FlushExecutionFiles()
-    reader = debug_events_reader.DebugEventsReader(self.dump_root)
-    graph_trace_iter = reader.graph_execution_traces_iterator()
-    try:
-      x_values = []
-      timestamp = 0
-      while True:
-        debug_event = next(graph_trace_iter)
-        self.assertGreater(debug_event.wall_time, timestamp)
-        timestamp = debug_event.wall_time
-        trace = debug_event.graph_execution_trace
-        self.assertEqual(trace.tfdbg_context_id, "deadbeaf")
-        self.assertEqual(trace.op_name, "x")
-        self.assertEqual(trace.output_slot, 0)
-        self.assertEqual(trace.tensor_debug_mode,
-                         debug_event_pb2.TensorDebugMode.FULL_TENSOR)
-        x_values.append(int(tensor_util.MakeNdarray(trace.tensor_proto)))
-    except StopIteration:
-      pass
+    with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
+      graph_trace_iter = reader.graph_execution_traces_iterator()
+      try:
+        x_values = []
+        timestamp = 0
+        while True:
+          debug_event = next(graph_trace_iter)
+          self.assertGreater(debug_event.wall_time, timestamp)
+          timestamp = debug_event.wall_time
+          trace = debug_event.graph_execution_trace
+          self.assertEqual(trace.tfdbg_context_id, "deadbeaf")
+          self.assertEqual(trace.op_name, "x")
+          self.assertEqual(trace.output_slot, 0)
+          self.assertEqual(trace.tensor_debug_mode,
+                           debug_event_pb2.TensorDebugMode.FULL_TENSOR)
+          x_values.append(int(tensor_util.MakeNdarray(trace.tensor_proto)))
+      except StopIteration:
+        pass
 
-    # Due to the circular buffer, only the last 4 iterations of
-    # [10, 5, 16, 8, 4, 2] should have been written.
-    self.assertAllEqual(x_values, [16, 8, 4, 2])
+      # Due to the circular buffer, only the last 4 iterations of
+      # [10, 5, 16, 8, 4, 2] should have been written.
+      self.assertAllEqual(x_values, [16, 8, 4, 2])
 
   @test_util.run_in_graph_and_eager_modes
   def testTwoDumpRoots(self):
@@ -210,20 +206,20 @@ class DebugIdentityV2OpTest(test_util.TensorFlowTestCase):
     another_writer.Close()
 
     for debug_root in (self.dump_root, another_dump_root):
-      reader = debug_events_reader.DebugEventsReader(debug_root)
-      graph_trace_iter = reader.graph_execution_traces_iterator()
+      with debug_events_reader.DebugEventsReader(debug_root) as reader:
+        graph_trace_iter = reader.graph_execution_traces_iterator()
 
-      debug_event = next(graph_trace_iter)
-      trace = debug_event.graph_execution_trace
-      self.assertEqual(trace.tfdbg_context_id, "deadbeaf")
-      self.assertEqual(trace.op_name, "")
-      self.assertEqual(trace.tensor_debug_mode,
-                       debug_event_pb2.TensorDebugMode.FULL_TENSOR)
-      tensor_value = tensor_util.MakeNdarray(trace.tensor_proto)
-      self.assertAllClose(tensor_value, [9.0, 16.0])
+        debug_event = next(graph_trace_iter)
+        trace = debug_event.graph_execution_trace
+        self.assertEqual(trace.tfdbg_context_id, "deadbeaf")
+        self.assertEqual(trace.op_name, "")
+        self.assertEqual(trace.tensor_debug_mode,
+                         debug_event_pb2.TensorDebugMode.FULL_TENSOR)
+        tensor_value = tensor_util.MakeNdarray(trace.tensor_proto)
+        self.assertAllClose(tensor_value, [9.0, 16.0])
 
-      with self.assertRaises(StopIteration):
-        next(graph_trace_iter)
+        with self.assertRaises(StopIteration):
+          next(graph_trace_iter)
 
   @test_util.run_in_graph_and_eager_modes
   def testDebugNumericSummaryV2OpReduceInfNanTwoSlots(self):
diff --git a/tensorflow/python/debug/lib/dumping_callback_test.py b/tensorflow/python/debug/lib/dumping_callback_test.py
index ed222585454..d32d543b382 100644
--- a/tensorflow/python/debug/lib/dumping_callback_test.py
+++ b/tensorflow/python/debug/lib/dumping_callback_test.py
@@ -112,84 +112,84 @@ class TracingCallbackTest(
 
     # Before FlushExecutionFiles() is called, the .execution file should be
     # empty.
-    reader = debug_events_reader.DebugEventsReader(self.dump_root)
-    execution_iter = reader.execution_iterator()
-    with self.assertRaises(StopIteration):
-      next(execution_iter)
+    with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
+      execution_iter = reader.execution_iterator()
+      with self.assertRaises(StopIteration):
+        next(execution_iter)
 
-    # After the flushing, the .execution file should hold the appropriate
-    # contents.
-    writer.FlushExecutionFiles()
-    execution_iter = reader.execution_iterator()
-    prev_wall_time = 1
-    executed_op_types = []
-    tensor_values = collections.defaultdict(lambda: [])
-    for debug_event in execution_iter:
-      self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
-      prev_wall_time = debug_event.wall_time
-      execution = debug_event.execution
-      executed_op_types.append(execution.op_type)
-      self.assertTrue(execution.input_tensor_ids)
-      self.assertTrue(execution.output_tensor_ids)
-      if tensor_debug_mode == "NO_TENSOR":
-        # Due to the NO_TENSOR tensor debug mode, tensor_protos ought to
-        # be empty.
-        self.assertFalse(execution.tensor_protos)
-      elif tensor_debug_mode == "FULL_TENSOR":
-        # Under the FULL_TENSOR mode, the value of the tensor should be
-        # available through `tensor_protos`.
-        tensor_value = float(
-            tensor_util.MakeNdarray(execution.tensor_protos[0]))
-        tensor_values[execution.op_type].append(tensor_value)
-      # Verify the code_location field.
-      self.assertTrue(execution.code_location.stack_frame_ids)
-      for stack_frame_id in execution.code_location.stack_frame_ids:
-        self.assertIn(stack_frame_id, stack_frame_by_id)
-    if tensor_debug_mode == "FULL_TENSOR":
-      self.assertAllClose(tensor_values["Greater"], [1, 1, 1, 1, 1, 1, 0])
-      self.assertAllClose(tensor_values["RealDiv"], [5, 8, 4, 2, 1])
-      self.assertAllClose(tensor_values["Mul"], [15])
-      self.assertAllClose(tensor_values["AddV2"], [16])
+      # After the flushing, the .execution file should hold the appropriate
+      # contents.
+      writer.FlushExecutionFiles()
+      execution_iter = reader.execution_iterator()
+      prev_wall_time = 1
+      executed_op_types = []
+      tensor_values = collections.defaultdict(lambda: [])
+      for debug_event in execution_iter:
+        self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
+        prev_wall_time = debug_event.wall_time
+        execution = debug_event.execution
+        executed_op_types.append(execution.op_type)
+        self.assertTrue(execution.input_tensor_ids)
+        self.assertTrue(execution.output_tensor_ids)
+        if tensor_debug_mode == "NO_TENSOR":
+          # Due to the NO_TENSOR tensor debug mode, tensor_protos ought to
+          # be empty.
+          self.assertFalse(execution.tensor_protos)
+        elif tensor_debug_mode == "FULL_TENSOR":
+          # Under the FULL_TENSOR mode, the value of the tensor should be
+          # available through `tensor_protos`.
+          tensor_value = float(
+              tensor_util.MakeNdarray(execution.tensor_protos[0]))
+          tensor_values[execution.op_type].append(tensor_value)
+        # Verify the code_location field.
+        self.assertTrue(execution.code_location.stack_frame_ids)
+        for stack_frame_id in execution.code_location.stack_frame_ids:
+          self.assertIn(stack_frame_id, stack_frame_by_id)
+      if tensor_debug_mode == "FULL_TENSOR":
+        self.assertAllClose(tensor_values["Greater"], [1, 1, 1, 1, 1, 1, 0])
+        self.assertAllClose(tensor_values["RealDiv"], [5, 8, 4, 2, 1])
+        self.assertAllClose(tensor_values["Mul"], [15])
+        self.assertAllClose(tensor_values["AddV2"], [16])
 
-    self.assertEqual(
-        executed_op_types,
-        [
-            "Greater",
-            "FloorMod",
-            "Equal",
-            "RealDiv",  # 10 --> 5
-            "Greater",
-            "FloorMod",
-            "Equal",
-            "Mul",
-            "AddV2",  # 5 --> 16
-            "Greater",
-            "FloorMod",
-            "Equal",
-            "RealDiv",  # 16 --> 8
-            "Greater",
-            "FloorMod",
-            "Equal",
-            "RealDiv",  # 8 --> 4
-            "Greater",
-            "FloorMod",
-            "Equal",
-            "RealDiv",  # 4 --> 2
-            "Greater",
-            "FloorMod",
-            "Equal",
-            "RealDiv",  # 2 --> 1
-            "Greater"
-        ])
+      self.assertEqual(
+          executed_op_types,
+          [
+              "Greater",
+              "FloorMod",
+              "Equal",
+              "RealDiv",  # 10 --> 5
+              "Greater",
+              "FloorMod",
+              "Equal",
+              "Mul",
+              "AddV2",  # 5 --> 16
+              "Greater",
+              "FloorMod",
+              "Equal",
+              "RealDiv",  # 16 --> 8
+              "Greater",
+              "FloorMod",
+              "Equal",
+              "RealDiv",  # 8 --> 4
+              "Greater",
+              "FloorMod",
+              "Equal",
+              "RealDiv",  # 4 --> 2
+              "Greater",
+              "FloorMod",
+              "Equal",
+              "RealDiv",  # 2 --> 1
+              "Greater"
+          ])
 
-    # Due to the pure eager op execution, the .graph file and the
-    # .graph_execution_traces file ought to be empty.
-    graphs_iterator = reader.graphs_iterator()
-    with self.assertRaises(StopIteration):
-      next(graphs_iterator)
-    graph_trace_iter = reader.graph_execution_traces_iterator()
-    with self.assertRaises(StopIteration):
-      next(graph_trace_iter)
+      # Due to the pure eager op execution, the .graph file and the
+      # .graph_execution_traces file ought to be empty.
+      graphs_iterator = reader.graphs_iterator()
+      with self.assertRaises(StopIteration):
+        next(graphs_iterator)
+      graph_trace_iter = reader.graph_execution_traces_iterator()
+      with self.assertRaises(StopIteration):
+        next(graph_trace_iter)
 
   @parameterized.named_parameters(
       ("NoTensor", "NO_TENSOR"),
@@ -425,66 +425,67 @@ class TracingCallbackTest(
 
     # Before FlushExecutionFiles() is called, the .execution and
     # .graph_execution_traces files should be both empty.
-    reader = debug_events_reader.DebugEventsReader(self.dump_root)
-    execution_iter = reader.execution_iterator()
-    graph_execution_traces_iter = reader.graph_execution_traces_iterator()
-    with self.assertRaises(StopIteration):
-      next(execution_iter)
-    with self.assertRaises(StopIteration):
-      next(graph_execution_traces_iter)
+    with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
+      execution_iter = reader.execution_iterator()
+      graph_execution_traces_iter = reader.graph_execution_traces_iterator()
+      with self.assertRaises(StopIteration):
+        next(execution_iter)
+      with self.assertRaises(StopIteration):
+        next(graph_execution_traces_iter)
 
-    # TODO(cais): Backport execution instrumentation to tf.Session.
-    writer.FlushExecutionFiles()
-    # After the flushing, the .execution file should hold the appropriate
-    # contents.
-    if context.executing_eagerly():
-      (executed_op_types, input_tensor_ids, output_tensor_ids,
-       tensor_debug_modes, tensor_values) = self._readAndCheckExecutionFile()
-      # NOTE(b/142486213): Execution of the TF function happens with
-      # Session.run() in v1 graph mode, hence it doesn't get logged to the
-      # .execution file.
-      self.assertLen(executed_op_types, 1)
-      self.assertIn("iterative_doubling", executed_op_types[0])
-      self.assertLen(input_tensor_ids[0], 2)
-      self.assertLen(output_tensor_ids[0], 1)
-      self.assertEqual(tensor_debug_modes[0],
-                       debug_event_pb2.TensorDebugMode.Value(tensor_debug_mode))
-      if tensor_debug_mode == "FULL_TENSOR":
-        self.assertAllClose(tensor_values, [[8.0]])
+      # TODO(cais): Backport execution instrumentation to tf.Session.
+      writer.FlushExecutionFiles()
+      # After the flushing, the .execution file should hold the appropriate
+      # contents.
+      if context.executing_eagerly():
+        (executed_op_types, input_tensor_ids, output_tensor_ids,
+         tensor_debug_modes, tensor_values) = self._readAndCheckExecutionFile()
+        # NOTE(b/142486213): Execution of the TF function happens with
+        # Session.run() in v1 graph mode, hence it doesn't get logged to the
+        # .execution file.
+        self.assertLen(executed_op_types, 1)
+        self.assertIn("iterative_doubling", executed_op_types[0])
+        self.assertLen(input_tensor_ids[0], 2)
+        self.assertLen(output_tensor_ids[0], 1)
+        self.assertEqual(
+            tensor_debug_modes[0],
+            debug_event_pb2.TensorDebugMode.Value(tensor_debug_mode))
+        if tensor_debug_mode == "FULL_TENSOR":
+          self.assertAllClose(tensor_values, [[8.0]])
 
-    (op_names, _, output_slots,
-     tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
-    executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
-    # The Less op should have been executed 5 times.
-    self.assertEqual(executed_op_types.count("Less"), 5)
-    # The last executed op should be Less.
-    self.assertEqual(executed_op_types[-1], "Less")
-    # The Mul op should have been executed 4 times.
-    self.assertEqual(executed_op_types.count("Mul"), 4)
-    # The AddV2 op should have been run, but we refrain from asserting on how
-    # many times it's executed.
-    self.assertIn("AddV2", executed_op_types)
-    for output_slot in output_slots:
-      self.assertEqual(output_slot, 0)
-    if tensor_debug_mode == "NO_TENSOR":
-      # Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought to
-      # be an empty float32 tensor.
-      for tensor_value in tensor_values:
-        self.assertEqual(tensor_value.dtype, np.float32)
-        self.assertEqual(tensor_value.shape, (0,))
-    elif tensor_debug_mode == "FULL_TENSOR":
-      less_values = [
-          tensor_values[i]
-          for i, op_type in enumerate(executed_op_types)
-          if op_type == "Less"
-      ]
-      self.assertAllClose(less_values, [True, True, True, True, False])
-      mul_values = [
-          tensor_values[i]
-          for i, op_type in enumerate(executed_op_types)
-          if op_type == "Mul"
-      ]
-      self.assertAllClose(mul_values, [1.0, 2.0, 4.0, 8.0])
+      (op_names, _, output_slots,
+       tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
+      executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
+      # The Less op should have been executed 5 times.
+      self.assertEqual(executed_op_types.count("Less"), 5)
+      # The last executed op should be Less.
+      self.assertEqual(executed_op_types[-1], "Less")
+      # The Mul op should have been executed 4 times.
+      self.assertEqual(executed_op_types.count("Mul"), 4)
+      # The AddV2 op should have been run, but we refrain from asserting on how
+      # many times it's executed.
+      self.assertIn("AddV2", executed_op_types)
+      for output_slot in output_slots:
+        self.assertEqual(output_slot, 0)
+      if tensor_debug_mode == "NO_TENSOR":
+        # Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought
+        # to be an empty float32 tensor.
+        for tensor_value in tensor_values:
+          self.assertEqual(tensor_value.dtype, np.float32)
+          self.assertEqual(tensor_value.shape, (0,))
+      elif tensor_debug_mode == "FULL_TENSOR":
+        less_values = [
+            tensor_values[i]
+            for i, op_type in enumerate(executed_op_types)
+            if op_type == "Less"
+        ]
+        self.assertAllClose(less_values, [True, True, True, True, False])
+        mul_values = [
+            tensor_values[i]
+            for i, op_type in enumerate(executed_op_types)
+            if op_type == "Mul"
+        ]
+        self.assertAllClose(mul_values, [1.0, 2.0, 4.0, 8.0])
 
   def testCallingEnableTracingTwiceWithTheSameDumpRootIsIdempotent(self):
     dumping_callback.enable_dump_debug_info(self.dump_root)
@@ -497,17 +498,17 @@ class TracingCallbackTest(
     writer.FlushNonExecutionFiles()
     writer.FlushExecutionFiles()
 
-    reader = debug_events_reader.DebugEventsReader(self.dump_root)
-    execution_iter = reader.execution_iterator()
-    for _ in range(2):
-      debug_event = next(execution_iter)
-      self.assertGreater(debug_event.wall_time, 0)
-      execution = debug_event.execution
-      self.assertEqual(execution.op_type, "Unique")
-      self.assertEqual(execution.num_outputs, 2)
-      self.assertTrue(execution.code_location)
-    with self.assertRaises(StopIteration):
-      next(execution_iter)
+    with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
+      execution_iter = reader.execution_iterator()
+      for _ in range(2):
+        debug_event = next(execution_iter)
+        self.assertGreater(debug_event.wall_time, 0)
+        execution = debug_event.execution
+        self.assertEqual(execution.op_type, "Unique")
+        self.assertEqual(execution.num_outputs, 2)
+        self.assertTrue(execution.code_location)
+      with self.assertRaises(StopIteration):
+        next(execution_iter)
 
   def testCallingEnableTracingTwiceWithDifferentDumpRootsOverwrites(self):
     dumping_callback.enable_dump_debug_info(self.dump_root)
@@ -521,23 +522,24 @@ class TracingCallbackTest(
     writer.FlushNonExecutionFiles()
     writer.FlushExecutionFiles()
 
-    reader = debug_events_reader.DebugEventsReader(new_dump_root)
-    execution_iter = reader.execution_iterator()
-    for _ in range(2):
-      debug_event = next(execution_iter)
-      self.assertGreater(debug_event.wall_time, 0)
-      execution = debug_event.execution
-      self.assertEqual(execution.op_type, "Unique")
-      self.assertEqual(execution.num_outputs, 2)
-      self.assertTrue(execution.code_location)
-    with self.assertRaises(StopIteration):
-      next(execution_iter)
+    with debug_events_reader.DebugEventsReader(new_dump_root) as reader:
+      execution_iter = reader.execution_iterator()
+      for _ in range(2):
+        debug_event = next(execution_iter)
+        self.assertGreater(debug_event.wall_time, 0)
+        execution = debug_event.execution
+        self.assertEqual(execution.op_type, "Unique")
+        self.assertEqual(execution.num_outputs, 2)
+        self.assertTrue(execution.code_location)
+      with self.assertRaises(StopIteration):
+        next(execution_iter)
 
-    old_dump_root_reader = debug_events_reader.DebugEventsReader(self.dump_root)
-    execution_iter = old_dump_root_reader.execution_iterator()
-    # The old dump root shouldn't have been written to.
-    with self.assertRaises(StopIteration):
-      next(execution_iter)
+      with debug_events_reader.DebugEventsReader(
+          self.dump_root) as old_dump_root_reader:
+        execution_iter = old_dump_root_reader.execution_iterator()
+        # The old dump root shouldn't have been written to.
+        with self.assertRaises(StopIteration):
+          next(execution_iter)
 
   def testCallingEnableRepeatedlyWithDifferentTensorDebugMode(self):
     """Assert that calling enable_dump_debug_info() with different tensor-debug modes.
@@ -586,17 +588,17 @@ class TracingCallbackTest(
     writer.FlushNonExecutionFiles()
     writer.FlushExecutionFiles()
 
-    reader = debug_events_reader.DebugEventsReader(self.dump_root)
-    source_files_iter = reader.source_files_iterator()
-    stack_frames_iter = reader.stack_frames_iterator()
-    execution_iter = reader.execution_iterator()
-    # No source-file, stack-frame or execution data should have been dumped.
-    with self.assertRaises(StopIteration):
-      next(source_files_iter)
-    with self.assertRaises(StopIteration):
-      next(stack_frames_iter)
-    with self.assertRaises(StopIteration):
-      next(execution_iter)
+    with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
+      source_files_iter = reader.source_files_iterator()
+      stack_frames_iter = reader.stack_frames_iterator()
+      execution_iter = reader.execution_iterator()
+      # No source-file, stack-frame or execution data should have been dumped.
+      with self.assertRaises(StopIteration):
+        next(source_files_iter)
+      with self.assertRaises(StopIteration):
+        next(stack_frames_iter)
+      with self.assertRaises(StopIteration):
+        next(execution_iter)
 
   @parameterized.named_parameters(
       ("NoTensor", "NO_TENSOR"),
@@ -630,12 +632,12 @@ class TracingCallbackTest(
     writer.FlushExecutionFiles()
 
     stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
-    reader = debug_events_reader.DebugEventsReader(self.dump_root)
-    execution_iter = reader.execution_iterator()
-    prev_wall_time = 1
-    for debug_event in execution_iter:
-      self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
-      prev_wall_time = debug_event.wall_time
+    with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
+      execution_iter = reader.execution_iterator()
+      prev_wall_time = 1
+      for debug_event in execution_iter:
+        self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
+        prev_wall_time = debug_event.wall_time
 
     (context_ids, _,
      op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
diff --git a/tensorflow/python/debug/lib/dumping_callback_test_lib.py b/tensorflow/python/debug/lib/dumping_callback_test_lib.py
index e572c48d04c..74261f918ce 100644
--- a/tensorflow/python/debug/lib/dumping_callback_test_lib.py
+++ b/tensorflow/python/debug/lib/dumping_callback_test_lib.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
 import os
 import shutil
 import socket
@@ -49,59 +50,60 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase):
 
   def _readAndCheckMetadataFile(self):
     """Read and check the .metadata debug-events file."""
-    reader = debug_events_reader.DebugEventsReader(self.dump_root)
-    metadata_iter = reader.metadata_iterator()
-    metadata = next(metadata_iter).debug_metadata
-    self.assertEqual(metadata.tensorflow_version, versions.__version__)
-    self.assertTrue(metadata.file_version.startswith("debug.Event"))
+    with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
+      metadata_iter = reader.metadata_iterator()
+      metadata = next(metadata_iter).debug_metadata
+      self.assertEqual(metadata.tensorflow_version, versions.__version__)
+      self.assertTrue(metadata.file_version.startswith("debug.Event"))
 
   def _readAndCheckSourceFilesAndStackFrames(self):
     """Read and verify the .source_files & .stack_frames debug-event files.
 
     Returns:
-      A dict mapping stack frame IDs to stack frames (FileLineCol).
+      An OrderedDict mapping stack frame IDs to stack frames (FileLineCol).
     """
-    reader = debug_events_reader.DebugEventsReader(self.dump_root)
-    # Check the content of the .source_files file.
-    source_files_iter = reader.source_files_iterator()
-    source_file_paths = []
-    prev_wall_time = 1
-    for debug_event in source_files_iter:
-      self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
-      prev_wall_time = debug_event.wall_time
-      source_file = debug_event.source_file
-      self.assertEqual(source_file.host_name, socket.gethostname())
-      self.assertTrue(source_file.file_path)
-      if source_file.lines:
-        self.assertTrue(os.path.isfile(source_file.file_path))
-      source_file_paths.append(source_file.file_path)
-    # Assert the file paths are unique.
-    self.assertEqual(len(source_file_paths), len(set(source_file_paths)))
+    with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
+      # Check the content of the .source_files file.
+      source_files_iter = reader.source_files_iterator()
+      source_file_paths = []
+      prev_wall_time = 1
+      for debug_event in source_files_iter:
+        self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
+        prev_wall_time = debug_event.wall_time
+        source_file = debug_event.source_file
+        self.assertEqual(source_file.host_name, socket.gethostname())
+        self.assertTrue(source_file.file_path)
+        if source_file.lines:
+          self.assertTrue(os.path.isfile(source_file.file_path))
+        source_file_paths.append(source_file.file_path)
+      # Assert the file paths are unique.
+      self.assertEqual(len(source_file_paths), len(set(source_file_paths)))
 
-    # Check the content of the .stack_frames file.
-    stack_frame_by_id = dict()  # A map from ID to stack frame.
-    stack_frames_iter = reader.stack_frames_iterator()
-    prev_wall_time = 0
-    for debug_event in stack_frames_iter:
-      self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
-      prev_wall_time = debug_event.wall_time
-      stack_frame_with_id = debug_event.stack_frame_with_id
-      stack_frame_id = stack_frame_with_id.id
-      file_line_col = stack_frame_with_id.file_line_col
-      self.assertTrue(stack_frame_id)
-      self.assertNotIn(stack_frame_id, stack_frame_by_id,
-                       "Duplicate stack frame ID: %s" % id)
-      stack_frame_by_id[stack_frame_id] = (file_line_col.file_index,
-                                           file_line_col.line,
-                                           file_line_col.func)
-      self.assertGreaterEqual(file_line_col.file_index, 0)
-      self.assertLess(file_line_col.file_index, len(source_file_paths))
-      self.assertTrue(file_line_col.line)  # Line numbers are 1-based.
-      self.assertTrue(file_line_col.func)
-    # Assert the stack frames are unique.
-    self.assertEqual(
-        len(stack_frame_by_id.values()), len(set(stack_frame_by_id.values())))
-    return stack_frame_by_id
+      # Check the content of the .stack_frames file.
+      # A map from ID to stack frame.
+      stack_frame_by_id = collections.OrderedDict()
+      stack_frames_iter = reader.stack_frames_iterator()
+      prev_wall_time = 0
+      for debug_event in stack_frames_iter:
+        self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
+        prev_wall_time = debug_event.wall_time
+        stack_frame_with_id = debug_event.stack_frame_with_id
+        stack_frame_id = stack_frame_with_id.id
+        file_line_col = stack_frame_with_id.file_line_col
+        self.assertTrue(stack_frame_id)
+        self.assertNotIn(stack_frame_id, stack_frame_by_id,
+                         "Duplicate stack frame ID: %s" % id)
+        stack_frame_by_id[stack_frame_id] = (file_line_col.file_index,
+                                             file_line_col.line,
+                                             file_line_col.func)
+        self.assertGreaterEqual(file_line_col.file_index, 0)
+        self.assertLess(file_line_col.file_index, len(source_file_paths))
+        self.assertTrue(file_line_col.line)  # Line numbers are 1-based.
+        self.assertTrue(file_line_col.func)
+      # Assert the stack frames are unique.
+      self.assertEqual(
+          len(stack_frame_by_id.values()), len(set(stack_frame_by_id.values())))
+      return stack_frame_by_id
 
   def _readAndCheckGraphsFile(self, stack_frame_by_id):
     """Read and verify the content of the .graphs debug-event file.
@@ -115,73 +117,72 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase):
         `list` of `str`s.
       op_types: Types of the ops that are created, as a `list` of `str`s with
         the same length as `context_ids`.
-      op_name_to_op_type: A `dict` mapping op name to op type.
+      op_name_to_op_type: An `OrderedDict` mapping op name to op type.
       op_name_to_context_id: A `dict` mapping op name to the ID of the innermost
         containing graph (context).
     """
-    reader = debug_events_reader.DebugEventsReader(self.dump_root)
-    graphs_iter = reader.graphs_iterator()
-    prev_wall_time = 0
-    op_types = []
-    op_name_to_op_type = dict()
-    op_name_to_context_id = dict()  # Maps op name to ID of innermost context.
-    context_ids = set()
-    symbolic_tensor_ids = set()
-    # Maps context ID to ID of directly enclosing context (`None` for
-    # outermost contexts).
-    context_id_to_outer_id = dict()
+    with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
+      graphs_iter = reader.graphs_iterator()
+      prev_wall_time = 0
+      op_types = []
+      op_name_to_op_type = collections.OrderedDict()
+      op_name_to_context_id = dict()  # Maps op name to ID of innermost context.
+      context_ids = set()
+      symbolic_tensor_ids = set()
+      # Maps context ID to ID of directly enclosing context (`None` for
+      # outermost contexts).
+      context_id_to_outer_id = dict()
 
-    for debug_event in graphs_iter:
-      self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
-      prev_wall_time = debug_event.wall_time
-      # A DebugEvent in the .graphs file contains either of the two fields:
-      # - graph_op_creation for creation of a symbolic op in a graph context.
-      # - debugged_graph for information regarding the graph (context).
-      if debug_event.graph_op_creation.ByteSize():
-        graph_op_creation = debug_event.graph_op_creation
-        self.assertTrue(graph_op_creation.op_type)
-        op_types.append(graph_op_creation.op_type)
-        self.assertTrue(graph_op_creation.op_name)
-        op_name_to_op_type[
-            graph_op_creation.op_name] = graph_op_creation.op_type
-        op_name_to_context_id[
-            graph_op_creation.op_name] = graph_op_creation.graph_id
-        self.assertTrue(graph_op_creation.graph_id)
-        context_ids.add(graph_op_creation.graph_id)
-        self.assertTrue(graph_op_creation.code_location)
-        if graph_op_creation.num_outputs:
-          self.assertLen(graph_op_creation.output_tensor_ids,
-                         graph_op_creation.num_outputs)
-          # Check that all symblic tensor IDs are unique.
-          for tensor_id in graph_op_creation.output_tensor_ids:
-            self.assertNotIn(tensor_id, symbolic_tensor_ids)
-            symbolic_tensor_ids.add(tensor_id)
-        for stack_frame_id in graph_op_creation.code_location.stack_frame_ids:
-          self.assertIn(stack_frame_id, stack_frame_by_id)
-      else:
-        debugged_graph = debug_event.debugged_graph
-        if debugged_graph.outer_context_id:
-          inner_id = debugged_graph.graph_id
-          outer_id = debugged_graph.outer_context_id
-          if inner_id in context_id_to_outer_id:
-            # The outer context of a context must be always the same.
-            self.assertEqual(context_id_to_outer_id[inner_id], outer_id)
-          else:
-            context_id_to_outer_id[inner_id] = outer_id
+      for debug_event in graphs_iter:
+        self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
+        prev_wall_time = debug_event.wall_time
+        # A DebugEvent in the .graphs file contains either of the two fields:
+        # - graph_op_creation for creation of a symbolic op in a graph context.
+        # - debugged_graph for information regarding the graph (context).
+        if debug_event.graph_op_creation.ByteSize():
+          graph_op_creation = debug_event.graph_op_creation
+          self.assertTrue(graph_op_creation.op_type)
+          op_types.append(graph_op_creation.op_type)
+          self.assertTrue(graph_op_creation.op_name)
+          op_name_to_op_type[
+              graph_op_creation.op_name] = graph_op_creation.op_type
+          op_name_to_context_id[
+              graph_op_creation.op_name] = graph_op_creation.graph_id
+          self.assertTrue(graph_op_creation.graph_id)
+          context_ids.add(graph_op_creation.graph_id)
+          self.assertTrue(graph_op_creation.code_location)
+          if graph_op_creation.num_outputs:
+            self.assertLen(graph_op_creation.output_tensor_ids,
+                           graph_op_creation.num_outputs)
+            # Check that all symblic tensor IDs are unique.
+            for tensor_id in graph_op_creation.output_tensor_ids:
+              self.assertNotIn(tensor_id, symbolic_tensor_ids)
+              symbolic_tensor_ids.add(tensor_id)
+          for stack_frame_id in graph_op_creation.code_location.stack_frame_ids:
+            self.assertIn(stack_frame_id, stack_frame_by_id)
         else:
-          # This is an outermost context.
-          if debugged_graph.graph_id in context_id_to_outer_id:
-            self.assertIsNone(context_id_to_outer_id[debugged_graph.graph_id])
+          debugged_graph = debug_event.debugged_graph
+          if debugged_graph.outer_context_id:
+            inner_id = debugged_graph.graph_id
+            outer_id = debugged_graph.outer_context_id
+            if inner_id in context_id_to_outer_id:
+              # The outer context of a context must be always the same.
+              self.assertEqual(context_id_to_outer_id[inner_id], outer_id)
+            else:
+              context_id_to_outer_id[inner_id] = outer_id
           else:
-            context_id_to_outer_id[debugged_graph.graph_id] = None
+            # This is an outermost context.
+            if debugged_graph.graph_id in context_id_to_outer_id:
+              self.assertIsNone(context_id_to_outer_id[debugged_graph.graph_id])
+            else:
+              context_id_to_outer_id[debugged_graph.graph_id] = None
 
-    # If any graph is created, the graph context hierarchy must be populated.
-    # In addition, the context of each graph op must be locatable within the
-    # graph context hierarchy.
-    for context_id in op_name_to_context_id.values():
-      self.assertIn(context_id, context_id_to_outer_id)
-
-    return context_ids, op_types, op_name_to_op_type, op_name_to_context_id
+      # If any graph is created, the graph context hierarchy must be populated.
+      # In addition, the context of each graph op must be locatable within the
+      # graph context hierarchy.
+      for context_id in op_name_to_context_id.values():
+        self.assertIn(context_id, context_id_to_outer_id)
+      return context_ids, op_types, op_name_to_op_type, op_name_to_context_id
 
   def _readAndCheckExecutionFile(self, dump_root=None):
     """Read and verify the content of the .execution debug-event file.
@@ -204,31 +205,30 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase):
         output tensor slot of the executed op or Function.
     """
     dump_root = self.dump_root if dump_root is None else dump_root
-    reader = debug_events_reader.DebugEventsReader(dump_root)
-    execution_iter = reader.execution_iterator()
-    prev_wall_time = 1
-    executed_op_types = []
-    input_tensor_ids = []
-    output_tensor_ids = []
-    tensor_debug_modes = []
-    tensor_values = []
-    for debug_event in execution_iter:
-      self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
-      prev_wall_time = debug_event.wall_time
-      execution = debug_event.execution
-      executed_op_types.append(execution.op_type)
-      input_tensor_ids.append(execution.input_tensor_ids)
-      output_tensor_ids.append(execution.output_tensor_ids)
-      tensor_debug_modes.append(execution.tensor_debug_mode)
-      tensor_values.append([
-          tensor_util.MakeNdarray(tensor_proto)
-          for tensor_proto in execution.tensor_protos
-      ])
-
-    # TODO(cais): When tensor debug modes other than NO_TENSOR is supported,
-    # return tensor_values as well.
-    return (executed_op_types, input_tensor_ids, output_tensor_ids,
-            tensor_debug_modes, tensor_values)
+    with debug_events_reader.DebugEventsReader(dump_root) as reader:
+      execution_iter = reader.execution_iterator()
+      prev_wall_time = 1
+      executed_op_types = []
+      input_tensor_ids = []
+      output_tensor_ids = []
+      tensor_debug_modes = []
+      tensor_values = []
+      for debug_event in execution_iter:
+        self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
+        prev_wall_time = debug_event.wall_time
+        execution = debug_event.execution
+        executed_op_types.append(execution.op_type)
+        input_tensor_ids.append(execution.input_tensor_ids)
+        output_tensor_ids.append(execution.output_tensor_ids)
+        tensor_debug_modes.append(execution.tensor_debug_mode)
+        tensor_values.append([
+            tensor_util.MakeNdarray(tensor_proto)
+            for tensor_proto in execution.tensor_protos
+        ])
+      # TODO(cais): When tensor debug modes other than NO_TENSOR is supported,
+      # return tensor_values as well.
+      return (executed_op_types, input_tensor_ids, output_tensor_ids,
+              tensor_debug_modes, tensor_values)
 
   def _readAndCheckGraphExecutionTracesFile(self, context_ids):
     """Read & verify the content of the .graph_execution_trace debug-event file.
@@ -247,29 +247,29 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase):
       tensor_values: Tensor values or their concise summaries, depending on
         TensorDebugMode.
     """
-    reader = debug_events_reader.DebugEventsReader(self.dump_root)
-    graph_execution_traces_iter = reader.graph_execution_traces_iterator()
-    op_names = []
-    device_names = []
-    output_slots = []
-    tensor_values = []
-    for debug_event in graph_execution_traces_iter:
-      self.assertGreaterEqual(debug_event.wall_time, 0)
-      graph_execution_trace = debug_event.graph_execution_trace
-      op_names.append(graph_execution_trace.op_name)
-      self.assertTrue(graph_execution_trace.device_name)
-      device_names.append(graph_execution_trace.device_name)
-      # All the ops in the graph have only one output.
-      self.assertTrue(graph_execution_trace.tfdbg_context_id)
-      self.assertIn(graph_execution_trace.tfdbg_context_id, context_ids)
-      output_slots.append(graph_execution_trace.output_slot)
-      dtype = dtypes.DType(graph_execution_trace.tensor_proto.dtype)
-      if (dtype.is_numpy_compatible and
-          dtype._type_enum != types_pb2.DT_STRING):  # pylint:disable=protected-access
-        # TODO(cais): Figure out how to properly convert string tensor proto to
-        # numpy representation.
-        tensor_values.append(
-            tensor_util.MakeNdarray(graph_execution_trace.tensor_proto))
-      else:
-        tensor_values.append(None)
-    return op_names, device_names, output_slots, tensor_values
+    with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
+      graph_execution_traces_iter = reader.graph_execution_traces_iterator()
+      op_names = []
+      device_names = []
+      output_slots = []
+      tensor_values = []
+      for debug_event in graph_execution_traces_iter:
+        self.assertGreaterEqual(debug_event.wall_time, 0)
+        graph_execution_trace = debug_event.graph_execution_trace
+        op_names.append(graph_execution_trace.op_name)
+        self.assertTrue(graph_execution_trace.device_name)
+        device_names.append(graph_execution_trace.device_name)
+        # All the ops in the graph have only one output.
+        self.assertTrue(graph_execution_trace.tfdbg_context_id)
+        self.assertIn(graph_execution_trace.tfdbg_context_id, context_ids)
+        output_slots.append(graph_execution_trace.output_slot)
+        dtype = dtypes.DType(graph_execution_trace.tensor_proto.dtype)
+        if (dtype.is_numpy_compatible and
+            dtype._type_enum != types_pb2.DT_STRING):  # pylint:disable=protected-access
+          # TODO(cais): Figure out how to properly convert string tensor proto
+          # to numpy representation.
+          tensor_values.append(
+              tensor_util.MakeNdarray(graph_execution_trace.tensor_proto))
+        else:
+          tensor_values.append(None)
+      return op_names, device_names, output_slots, tensor_values

From 992c03f0b5766e6570af9e2ef7f8cb6e89dcf000 Mon Sep 17 00:00:00 2001
From: Yanhui Liang 
Date: Tue, 3 Dec 2019 16:29:56 -0800
Subject: [PATCH 257/279] Reuse the `_MultiIOSubclassModel` class in the unit
 tests of subclass Model.

PiperOrigin-RevId: 283647635
Change-Id: I52fc8e1b5c88719debfdb2fc0ee18ef0b5a84d70
---
 .../keras/model_subclassing_compiled_test.py  | 26 ++++++++-----
 .../python/keras/model_subclassing_test.py    |  8 ++--
 .../keras/model_subclassing_test_util.py      | 39 ++++++++-----------
 tensorflow/python/keras/testing_utils.py      |  4 +-
 4 files changed, 39 insertions(+), 38 deletions(-)

diff --git a/tensorflow/python/keras/model_subclassing_compiled_test.py b/tensorflow/python/keras/model_subclassing_compiled_test.py
index 180e8c8b735..bf27b3bf8a7 100644
--- a/tensorflow/python/keras/model_subclassing_compiled_test.py
+++ b/tensorflow/python/keras/model_subclassing_compiled_test.py
@@ -64,7 +64,7 @@ class ModelSubclassCompiledTest(keras_parameterized.TestCase):
     num_samples = 1000
     input_dim = 50
 
-    model = model_util.MultiIOTestModel(
+    model = model_util.get_multi_io_subclass_model(
         num_classes=num_classes, use_dp=True, use_bn=True)
     model.compile(
         loss='mse',
@@ -111,7 +111,8 @@ class ModelSubclassCompiledTest(keras_parameterized.TestCase):
     num_samples = 100
     input_dim = 50
 
-    model = model_util.MultiIOTestModel(num_classes=num_classes, use_bn=True)
+    model = model_util.get_multi_io_subclass_model(
+        num_classes=num_classes, use_bn=True)
 
     x1 = np.ones((num_samples, input_dim))
     x2 = np.ones((num_samples, input_dim))
@@ -211,7 +212,8 @@ class ModelSubclassCompiledTest(keras_parameterized.TestCase):
     y1 = np.zeros((num_samples, num_classes[0]))
     y2 = np.zeros((num_samples, num_classes[1]))
 
-    model = model_util.MultiIOTestModel(num_classes=num_classes, use_bn=True)
+    model = model_util.get_multi_io_subclass_model(
+        num_classes=num_classes, use_bn=True)
     model.compile(
         loss='mse',
         optimizer='rmsprop',
@@ -224,7 +226,8 @@ class ModelSubclassCompiledTest(keras_parameterized.TestCase):
     model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0,
               validation_data=([x1, x2], [y1, y2]))
 
-    model = model_util.MultiIOTestModel(num_classes=num_classes, use_bn=True)
+    model = model_util.get_multi_io_subclass_model(
+        num_classes=num_classes, use_bn=True)
     model.compile(
         loss='mse',
         optimizer='rmsprop',
@@ -246,7 +249,8 @@ class ModelSubclassCompiledTest(keras_parameterized.TestCase):
     y1 = np.zeros((num_samples, num_classes[0]))
     y2 = np.zeros((num_samples, num_classes[1]))
 
-    model = model_util.MultiIOTestModel(num_classes=num_classes, use_bn=True)
+    model = model_util.get_multi_io_subclass_model(
+        num_classes=num_classes, use_bn=True)
     model.compile(
         loss='mse',
         optimizer='rmsprop',
@@ -255,10 +259,12 @@ class ModelSubclassCompiledTest(keras_parameterized.TestCase):
     model.evaluate([x1, x2], [y1, y2])
     model.test_on_batch([x1, x2], [y1, y2])
 
-    model = model_util.MultiIOTestModel(num_classes=num_classes, use_bn=True)
+    model = model_util.get_multi_io_subclass_model(
+        num_classes=num_classes, use_bn=True)
     model.predict([x1, x2])
 
-    model = model_util.MultiIOTestModel(num_classes=num_classes, use_bn=True)
+    model = model_util.get_multi_io_subclass_model(
+        num_classes=num_classes, use_bn=True)
     model.predict_on_batch([x1, x2])
 
   def test_saving(self):
@@ -271,7 +277,8 @@ class ModelSubclassCompiledTest(keras_parameterized.TestCase):
     y1 = np.zeros((num_samples, num_classes[0]))
     y2 = np.zeros((num_samples, num_classes[1]))
 
-    model = model_util.MultiIOTestModel(num_classes=num_classes, use_bn=True)
+    model = model_util.get_multi_io_subclass_model(
+        num_classes=num_classes, use_bn=True)
     model.compile(
         loss='mse',
         optimizer='rmsprop',
@@ -286,7 +293,8 @@ class ModelSubclassCompiledTest(keras_parameterized.TestCase):
       hdf5_format_name = os.path.join(self.get_temp_dir(), 'weights.h5')
       model.save_weights(hdf5_format_name)
 
-    model = model_util.MultiIOTestModel(num_classes=num_classes, use_bn=True)
+    model = model_util.get_multi_io_subclass_model(
+        num_classes=num_classes, use_bn=True)
 
     if h5py is not None:
       with self.assertRaises(ValueError):
diff --git a/tensorflow/python/keras/model_subclassing_test.py b/tensorflow/python/keras/model_subclassing_test.py
index bbab5637a89..a4b8ac92b03 100644
--- a/tensorflow/python/keras/model_subclassing_test.py
+++ b/tensorflow/python/keras/model_subclassing_test.py
@@ -313,7 +313,7 @@ class ModelSubclassingTest(keras_parameterized.TestCase):
     batch_size = None
     num_samples = 1000
     input_dim = 50
-    model = model_util.MultiIOTestModel()
+    model = model_util.get_multi_io_subclass_model()
     self.assertFalse(model.built, 'Model should not have been built')
     self.assertFalse(model.weights, ('Model should have no weights since it '
                                      'has not been built.'))
@@ -345,7 +345,7 @@ class ModelSubclassingTest(keras_parameterized.TestCase):
     self.assertTrue('Trainable params: 356' in print_fn.contents)
 
     # Multi-io
-    model = model_util.MultiIOTestModel(
+    model = model_util.get_multi_io_subclass_model(
         num_classes=(5, 6), use_bn=True, use_dp=True)
     model._set_inputs([np.ones((3, 4)),
                        np.ones((3, 4))])  # need to build model first
@@ -508,7 +508,7 @@ class GraphSpecificModelSubclassingTests(test.TestCase):
     input_dim = 50
 
     with self.cached_session():
-      model = model_util.MultiIOTestModel(
+      model = model_util.get_multi_io_subclass_model(
           num_classes=num_classes, use_dp=True, use_bn=True)
       model.compile(loss='mse', optimizer='rmsprop')
 
@@ -595,7 +595,7 @@ class GraphSpecificModelSubclassingTests(test.TestCase):
     input_dim = 50
 
     with self.cached_session():
-      model = model_util.MultiIOTestModel(
+      model = model_util.get_multi_io_subclass_model(
           num_classes=num_classes, use_dp=True, use_bn=True)
       model.compile(loss='mse', optimizer='rmsprop')
 
diff --git a/tensorflow/python/keras/model_subclassing_test_util.py b/tensorflow/python/keras/model_subclassing_test_util.py
index 0f07c716b80..cf627b984a1 100644
--- a/tensorflow/python/keras/model_subclassing_test_util.py
+++ b/tensorflow/python/keras/model_subclassing_test_util.py
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.python import keras
+from tensorflow.python.keras import testing_utils
 
 
 # pylint: disable=missing-docstring,not-callable
@@ -62,31 +63,23 @@ class SimpleConvTestModel(keras.Model):
     return self.dense1(x)
 
 
-class MultiIOTestModel(keras.Model):
+def get_multi_io_subclass_model(use_bn=False, use_dp=False, num_classes=(2, 3)):
+  """Creates MultiIOModel for the tests of subclass model."""
+  shared_layer = keras.layers.Dense(32, activation='relu')
+  branch_a = [shared_layer]
+  if use_dp:
+    branch_a.append(keras.layers.Dropout(0.5))
+  branch_a.append(keras.layers.Dense(num_classes[0], activation='softmax'))
 
-  def __init__(self, use_bn=False, use_dp=False, num_classes=(2, 3)):
-    super(MultiIOTestModel, self).__init__(name='test_model')
-    self.use_bn = use_bn
-    self.use_dp = use_dp
-    self.num_classes = num_classes
+  branch_b = [shared_layer]
+  if use_bn:
+    branch_b.append(keras.layers.BatchNormalization())
+  branch_b.append(keras.layers.Dense(num_classes[1], activation='softmax'))
 
-    self.dense1 = keras.layers.Dense(32, activation='relu')
-    self.dense2 = keras.layers.Dense(num_classes[0], activation='softmax')
-    self.dense3 = keras.layers.Dense(num_classes[1], activation='softmax')
-    if use_dp:
-      self.dp = keras.layers.Dropout(0.5)
-    if use_bn:
-      self.bn = keras.layers.BatchNormalization()
-
-  def call(self, inputs):
-    x1, x2 = inputs
-    x1 = self.dense1(x1)
-    x2 = self.dense1(x2)
-    if self.use_dp:
-      x1 = self.dp(x1)
-    if self.use_bn:
-      x2 = self.bn(x2)
-    return [self.dense2(x1), self.dense3(x2)]
+  model = (
+      testing_utils._MultiIOSubclassModel(   # pylint: disable=protected-access
+          branch_a, branch_b, name='test_model'))
+  return model
 
 
 class NestedTestModel1(keras.Model):
diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py
index e4c2406399f..146807028cb 100644
--- a/tensorflow/python/keras/testing_utils.py
+++ b/tensorflow/python/keras/testing_utils.py
@@ -608,8 +608,8 @@ class _MultiIOSubclassModel(keras.Model):
   """Multi IO Keras subclass model."""
 
   def __init__(self, branch_a, branch_b, shared_input_branch=None,
-               shared_output_branch=None):
-    super(_MultiIOSubclassModel, self).__init__()
+               shared_output_branch=None, name=None):
+    super(_MultiIOSubclassModel, self).__init__(name=name)
     self._shared_input_branch = shared_input_branch
     self._branch_a = branch_a
     self._branch_b = branch_b

From e79efe524efad3b082c2a7d57e06433eaa8b9f03 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Tue, 3 Dec 2019 16:34:51 -0800
Subject: [PATCH 258/279] Provides a better example for `StackedRNNCells`. The
 current example does not demonstrate the use of the class.

PiperOrigin-RevId: 283648537
Change-Id: I30bb64c64cce1fe3b4bf7c57256be4d04e90f785
---
 tensorflow/python/keras/layers/recurrent.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py
index 87a99f49164..eb8f43fd993 100644
--- a/tensorflow/python/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/layers/recurrent.py
@@ -66,14 +66,17 @@ class StackedRNNCells(Layer):
   Examples:
 
   ```python
-  cells = [
-      keras.layers.LSTMCell(output_dim),
-      keras.layers.LSTMCell(output_dim),
-      keras.layers.LSTMCell(output_dim),
-  ]
+  batch_size = 3
+  sentence_max_length = 5
+  n_features = 2
+  new_shape = (batch_size, sentence_max_length, n_features)
+  x = tf.constant(np.reshape(np.arange(30), new_shape), dtype = tf.float32)
 
-  inputs = keras.Input((timesteps, input_dim))
-  x = keras.layers.RNN(cells)(inputs)
+  rnn_cells = [tf.keras.layers.LSTMCell(128) for _ in range(2)]
+  stacked_lstm = tf.keras.layers.StackedRNNCells(rnn_cells)
+  lstm_layer = tf.keras.layers.RNN(stacked_lstm)
+
+  result = lstm_layer(x)
   ```
   """
 

From 4b2f4f4d25a75593bdeaf993b8cdcfc7055a7a55 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Tue, 3 Dec 2019 16:40:54 -0800
Subject: [PATCH 259/279] Fix a documentation issue in Attention with incorrect
 name being used in example.

PiperOrigin-RevId: 283649606
Change-Id: Id0e53be00126ec4825efbdb4c59da791185e3275
---
 tensorflow/python/keras/layers/dense_attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/python/keras/layers/dense_attention.py b/tensorflow/python/keras/layers/dense_attention.py
index 054e840f48c..ff5e62d0d8b 100644
--- a/tensorflow/python/keras/layers/dense_attention.py
+++ b/tensorflow/python/keras/layers/dense_attention.py
@@ -243,7 +243,7 @@ class Attention(BaseDenseAttention):
   # Query embeddings of shape [batch_size, Tq, dimension].
   query_embeddings = token_embedding(query_input)
   # Value embeddings of shape [batch_size, Tv, dimension].
-  value_embeddings = token_embedding(query_input)
+  value_embeddings = token_embedding(value_input)
 
   # CNN layer.
   cnn_layer = tf.keras.layers.Conv1D(

From f412457edecee343faf169edbe720d186156d3e1 Mon Sep 17 00:00:00 2001
From: Lei Zhang 
Date: Tue, 3 Dec 2019 16:43:40 -0800
Subject: [PATCH 260/279] [spirv] Add spv.GroupNonUniformBallot

This CL also did the following cleanup:
- Moved the test for spv.SubgroupBallotKHR to its own file
- Wrapped generated canonicalization patterns in anonymous namespace
- Updated header comments in SPVOps.td

PiperOrigin-RevId: 283650091
Change-Id: Ia10cc3d8787a05c07f1c8e84b5d9d43cb4224f4a
---
 third_party/mlir/BUILD                        |  1 +
 .../include/mlir/Dialect/SPIRV/SPIRVBase.td   |  7 +-
 .../mlir/Dialect/SPIRV/SPIRVNonUniformOps.td  | 78 +++++++++++++++++++
 .../include/mlir/Dialect/SPIRV/SPIRVOps.td    | 12 +--
 .../mlir/lib/Dialect/SPIRV/SPIRVOps.cpp       | 40 ++++++++++
 5 files changed, 131 insertions(+), 7 deletions(-)
 create mode 100644 third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td

diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD
index 17e7c5b58c9..a522065d72a 100644
--- a/third_party/mlir/BUILD
+++ b/third_party/mlir/BUILD
@@ -953,6 +953,7 @@ filegroup(
         "include/mlir/Dialect/SPIRV/SPIRVGLSLOps.td",
         "include/mlir/Dialect/SPIRV/SPIRVGroupOps.td",
         "include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td",
+        "include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td",
         "include/mlir/Dialect/SPIRV/SPIRVOps.td",
         "include/mlir/Dialect/SPIRV/SPIRVStructureOps.td",
         ":OpBaseTdFiles",
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index bfb7497aada..2ee8f3bdd43 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -953,7 +953,9 @@ class SPV_ScalarOrVectorOf :
 def SPV_ScalarOrVector : AnyTypeOf<[SPV_Scalar, SPV_Vector]>;
 def SPV_ScalarOrVectorOrPtr : AnyTypeOf<[SPV_ScalarOrVector, SPV_AnyPtr]>;
 
-def SPV_I32Vec4 : VectorOfLengthAndType<[4], [I32]>;
+class SPV_Vec4 : VectorOfLengthAndType<[4], [type]>;
+def SPV_IntVec4 : SPV_Vec4;
+def SPV_I32Vec4 : SPV_Vec4;
 
 // TODO(antiagainst): Use a more appropriate way to model optional operands
 class SPV_Optional : Variadic;
@@ -1109,6 +1111,7 @@ def SPV_OC_OpReturn                 : I32EnumAttrCase<"OpReturn", 253>;
 def SPV_OC_OpReturnValue            : I32EnumAttrCase<"OpReturnValue", 254>;
 def SPV_OC_OpUnreachable            : I32EnumAttrCase<"OpUnreachable", 255>;
 def SPV_OC_OpModuleProcessed        : I32EnumAttrCase<"OpModuleProcessed", 330>;
+def SPV_OC_OpGroupNonUniformBallot  : I32EnumAttrCase<"OpGroupNonUniformBallot", 339>;
 def SPV_OC_OpSubgroupBallotKHR      : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
 
 def SPV_OpcodeAttr :
@@ -1150,7 +1153,7 @@ def SPV_OpcodeAttr :
       SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel,
       SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
       SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpModuleProcessed,
-      SPV_OC_OpSubgroupBallotKHR
+      SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpSubgroupBallotKHR
       ]> {
     let cppNamespace = "::mlir::spirv";
 }
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
new file mode 100644
index 00000000000..a37f5b576fd
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
@@ -0,0 +1,78 @@
+//===-- SPIRVNonUniformOps.td - MLIR SPIR-V NonUniform Ops -*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains non-uniform ops for the SPIR-V dialect. It corresponds to
+// "3.32.24. Non-Uniform Instructions" of the SPIR-V specification.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SPIRV_NON_UNIFORM_OPS
+#define SPIRV_NON_UNIFORM_OPS
+
+// -----
+
+def SPV_GroupNonUniformBallotOp : SPV_Op<"GroupNonUniformBallot", []> {
+  let summary = [{
+    Returns a bitfield value combining the Predicate value from all
+    invocations in the group that execute the same dynamic instance of this
+    instruction. The bit is set to one if the corresponding invocation is
+    active and the Predicate for that invocation evaluated to true;
+    otherwise, it is set to zero.
+  }];
+
+  let description = [{
+    Result Type  must be a vector of four components of integer type scalar,
+    whose Signedness operand is 0.
+
+    Result is a set of bitfields where the first invocation is represented
+    in the lowest bit of the first vector component and the last (up to the
+    size of the group) is the higher bit number of the last bitmask needed
+    to represent all bits of the group invocations.
+
+    Execution must be Workgroup or Subgroup Scope.
+
+    Predicate must be a Boolean type.
+
+    ### Custom assembly form
+
+    ``` {.ebnf}
+    scope ::= `"Workgroup"` | `"Subgroup"`
+    non-uniform-ballot-op ::= ssa-id `=` `spv.GroupNonUniformBallot` scope
+                              ssa-use `:` `vector` `<` 4 `x` `integer-type` `>`
+    ```
+
+    For example:
+
+    ```
+    %0 = spv.GroupNonUniformBallot "SubGroup" %predicate : vector<4xi32>
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_ScopeAttr:$execution_scope,
+    SPV_Bool:$predicate
+  );
+
+  let results = (outs
+    SPV_IntVec4:$result
+  );
+}
+
+// -----
+
+#endif // SPIRV_NON_UNIFORM_OPS
+
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
index 178db0add4e..149c2359fda 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
@@ -20,11 +20,12 @@
 //
 //===----------------------------------------------------------------------===//
 
-// Note that for each op in this file, we use a tool to automatically generate
-// certain sections in its definition: basic structure, summary, description.
-// So modifications to these sections will not be respected. Modifications to
-// op traits, arguments, results, and sections after the results are retained.
-// Besides, ops in this file must be separated via the '// -----' marker.
+// Note that for each op in this file and the included files for specific op
+// categories, we use a tool to automatically generate certain sections in its
+// definition: basic structure, summary, description. So modifications to these
+// sections will not be respected. Modifications to op traits, arguments,
+// results, and sections after the results are retained. Besides, ops must be
+// separated via the '// -----' marker.
 
 #ifndef SPIRV_OPS
 #define SPIRV_OPS
@@ -37,6 +38,7 @@ include "mlir/Dialect/SPIRV/SPIRVControlFlowOps.td"
 include "mlir/Dialect/SPIRV/SPIRVGLSLOps.td"
 include "mlir/Dialect/SPIRV/SPIRVGroupOps.td"
 include "mlir/Dialect/SPIRV/SPIRVLogicalOps.td"
+include "mlir/Dialect/SPIRV/SPIRVNonUniformOps.td"
 include "mlir/Dialect/SPIRV/SPIRVStructureOps.td"
 
 // -----
diff --git a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 6e115f7ba76..89abbe894e6 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -385,7 +385,9 @@ static inline bool isMergeBlock(Block &block) {
 // TableGen'erated canonicalizers
 //===----------------------------------------------------------------------===//
 
+namespace {
 #include "SPIRVCanonicalization.inc"
+}
 
 //===----------------------------------------------------------------------===//
 // Common parsers and printers
@@ -1551,6 +1553,44 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spv.GroupNonUniformBallotOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseGroupNonUniformBallotOp(OpAsmParser &parser,
+                                                OperationState &state) {
+  spirv::Scope executionScope;
+  OpAsmParser::OperandType operandInfo;
+  Type resultType;
+  IntegerType i1Type = parser.getBuilder().getI1Type();
+  if (parseEnumAttribute(executionScope, parser, state,
+                         kExecutionScopeAttrName) ||
+      parser.parseOperand(operandInfo) || parser.parseColonType(resultType) ||
+      parser.resolveOperand(operandInfo, i1Type, state.operands))
+    return failure();
+
+  return parser.addTypeToList(resultType, state.types);
+}
+
+static void print(spirv::GroupNonUniformBallotOp ballotOp,
+                  OpAsmPrinter &printer) {
+  printer << spirv::GroupNonUniformBallotOp::getOperationName() << " \""
+          << stringifyScope(ballotOp.execution_scope()) << "\" ";
+  printer.printOperand(ballotOp.predicate());
+  printer << " : " << ballotOp.getType();
+}
+
+static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
+  // TODO(antiagainst): check the result integer type's signedness bit is 0.
+
+  spirv::Scope scope = ballotOp.execution_scope();
+  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
+    return ballotOp.emitOpError(
+        "execution scope must be 'Workgroup' or 'Subgroup'");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spv.IAdd
 //===----------------------------------------------------------------------===//

From 56221776d3f705fb5f9574d11d0ec604e9a2738c Mon Sep 17 00:00:00 2001
From: Peter Hawkins 
Date: Tue, 3 Dec 2019 16:54:29 -0800
Subject: [PATCH 261/279] [XLA] Refactor Executable::ExecuteAsyncOnStream.

Change implementations of Executable to always implement the overload that takes a std::vector>. Make the non-owning version a wrapper around the maybe-owning version.

Simplification in preparation for plumbing buffer donation into JAX. This change is also a necessary preparatory step for implementing buffer donation on CPU and GPU.

PiperOrigin-RevId: 283651970
Change-Id: I8766ae34071727dab43d1ab2597ae2a4a19f11d3
---
 tensorflow/compiler/xla/service/cpu/BUILD     |  6 +++
 .../xla/service/cpu/cpu_executable.cc         | 40 ++++++++++++------
 .../compiler/xla/service/cpu/cpu_executable.h | 14 ++++---
 tensorflow/compiler/xla/service/executable.cc | 38 +++++++++++++----
 tensorflow/compiler/xla/service/executable.h  | 10 ++---
 .../xla/service/gpu/gpu_executable.cc         | 39 ++++++++++--------
 .../compiler/xla/service/gpu/gpu_executable.h |  9 +---
 .../service/hlo_input_output_alias_config.cc  |  3 +-
 .../compiler/xla/service/interpreter/BUILD    |  4 ++
 .../xla/service/interpreter/executable.cc     | 41 +++++++++++++++----
 .../xla/service/interpreter/executable.h      |  4 +-
 .../xla/service/maybe_owning_device_memory.cc |  3 +-
 .../xla/service/maybe_owning_device_memory.h  |  2 +-
 13 files changed, 144 insertions(+), 69 deletions(-)

diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 229827c77c8..bec66aea27f 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -242,9 +242,15 @@ cc_library(
         "//tensorflow/compiler/xla/service:hlo_dataflow_analysis",
         "//tensorflow/compiler/xla/service:hlo_execution_profile",
         "//tensorflow/compiler/xla/service:logical_buffer",
+        "//tensorflow/compiler/xla/service:maybe_owning_device_memory",
         "//tensorflow/compiler/xla/service:shaped_buffer",
         "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor_no_cuda",
+        "//tensorflow/core/platform:logging",
+        "//tensorflow/core/platform:macros",
+        "//tensorflow/core/platform:mutex",
+        "//tensorflow/core/platform:platform_port",
+        "//tensorflow/core/platform:types",
         "//tensorflow/core/profiler/lib:traceme",
         "//tensorflow/stream_executor:device_memory_allocator",
         "//tensorflow/stream_executor/host:host_stream",
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index 9b79e8ca8d7..d19cf4fb015 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -32,6 +32,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
 #include "tensorflow/compiler/xla/shape_tree.h"
 #include "tensorflow/compiler/xla/shape_util.h"
@@ -44,6 +45,7 @@ limitations under the License.
 #include "tensorflow/core/platform/mem.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/device_memory_allocator.h"
 #include "tensorflow/stream_executor/host/host_stream.h"
 
 namespace xla {
@@ -73,11 +75,12 @@ CpuExecutable::CpuExecutable(
           << reinterpret_cast(compute_function_);
 }
 
-StatusOr,
-                   std::vector>>
+StatusOr,
+                    std::vector,
+                    std::vector>>
 CpuExecutable::CreateBufferTable(
     se::DeviceMemoryAllocator* memory_allocator, int device_ordinal,
-    absl::Span arguments) {
+    std::vector> arguments) {
   std::vector unowning_buffers(
       assignment_->Allocations().size());
   std::vector owning_buffers(
@@ -91,8 +94,9 @@ CpuExecutable::CreateBufferTable(
     VLOG(3) << allocation.ToString();
 
     if (allocation.is_entry_computation_parameter()) {
-      unowning_buffers[i] = arguments[allocation.parameter_number()]->buffer(
-          allocation.param_shape_index());
+      unowning_buffers[i] = arguments[allocation.parameter_number()]
+                                .element(allocation.param_shape_index())
+                                .AsDeviceMemoryBase();
       CHECK_EQ(allocation.size(), unowning_buffers[i].size())
           << "Size mismatch on param " << allocation.parameter_number()
           << " at shape index " << allocation.param_shape_index().ToString();
@@ -134,7 +138,17 @@ CpuExecutable::CreateBufferTable(
                       assignment_->GetUniqueTopLevelOutputSlice());
   VLOG(3) << "result index: " << result_slice.index();
 
-  return {{std::move(unowning_buffers), std::move(owning_buffers)}};
+  std::vector buffers_to_free;
+  for (ShapeTree& argument : arguments) {
+    for (std::pair& buffer : argument) {
+      auto maybe_owning_buffer = buffer.second.Release();
+      if (maybe_owning_buffer) {
+        buffers_to_free.push_back(std::move(*maybe_owning_buffer));
+      }
+    }
+  }
+  return std::make_tuple(std::move(unowning_buffers), std::move(owning_buffers),
+                         std::move(buffers_to_free));
 }
 
 Status CpuExecutable::ExecuteComputeFunction(
@@ -268,9 +282,9 @@ StatusOr CpuExecutable::CreateResultShapedBuffer(
   return std::move(result_buffer);
 }
 
-StatusOr CpuExecutable::ExecuteAsyncOnStream(
+StatusOr CpuExecutable::ExecuteAsyncOnStream(
     const ServiceExecutableRunOptions* run_options,
-    absl::Span arguments,
+    std::vector> arguments,
     HloExecutionProfile* hlo_execution_profile) {
   if (GetRootValueSet().IsAmbiguous()) {
     return Unimplemented("Points-to set of root instruction is ambiguous");
@@ -283,7 +297,7 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream(
     for (int64 i = 0; i < entry_comp->num_parameters(); ++i) {
       const Shape& expected_shape =
           entry_comp->parameter_instruction(i)->shape();
-      const Shape& actual_shape = arguments[i]->on_device_shape();
+      const Shape& actual_shape = arguments[i].shape();
       CHECK(expected_shape == actual_shape) << absl::StreamFormat(
           "Shape mismatch on argument %d.  Expected %s, but was %s.", i,
           expected_shape.ToString(/*print_layout=*/true),
@@ -297,10 +311,11 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream(
   se::DeviceMemoryAllocator* memory_allocator = run_options->allocator();
   std::vector owning_buffers;
   std::vector unowning_buffers;
+  std::vector buffers_to_release;
   TF_ASSIGN_OR_RETURN(
-      std::tie(unowning_buffers, owning_buffers),
+      std::tie(unowning_buffers, owning_buffers, buffers_to_release),
       CreateBufferTable(memory_allocator, stream->parent()->device_ordinal(),
-                        arguments));
+                        std::move(arguments)));
 
   TF_ASSIGN_OR_RETURN(
       ScopedShapedBuffer result,
@@ -339,7 +354,8 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream(
                        std::move(owning_buffers)),
                    hlo_execution_profile});
 
-  return std::move(result);
+  return ExecutionOutput(std::move(result), std::move(buffers_to_release), {},
+                         se::OwningDeviceMemory());
 }
 
 /*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) {
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
index 37af630a2d9..6f8a7c3315a 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
@@ -55,9 +55,9 @@ class CpuExecutable : public Executable {
                 std::unique_ptr hlo_profile_index_map);
   ~CpuExecutable() override {}
 
-  StatusOr ExecuteAsyncOnStream(
+  StatusOr ExecuteAsyncOnStream(
       const ServiceExecutableRunOptions* run_options,
-      absl::Span arguments,
+      std::vector> arguments,
       HloExecutionProfile* hlo_execution_profile) override;
 
   // This should be called after set_ir_module_string.
@@ -96,11 +96,15 @@ class CpuExecutable : public Executable {
   //    allocated by this routine.  This routine allocates buffers for temporary
   //    storage and the live-out buffer into which the computation writes it
   //    result.
-  StatusOr,
-                     std::vector>>
+  //
+  //  - buffers_to_free: buffers whose ownership was donated by the caller that
+  //    are to be freed by the caller.
+  StatusOr,
+                      std::vector,
+                      std::vector>>
   CreateBufferTable(se::DeviceMemoryAllocator* memory_allocator,
                     int device_ordinal,
-                    absl::Span arguments);
+                    std::vector> arguments);
 
   // Calls the generated function performing the computation with the given
   // arguments using the supplied buffers.
diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc
index c21721c9339..9ece6172d12 100644
--- a/tensorflow/compiler/xla/service/executable.cc
+++ b/tensorflow/compiler/xla/service/executable.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/debug_options_flags.h"
 #include "tensorflow/compiler/xla/service/dump.h"
 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
+#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
 #include "tensorflow/compiler/xla/status.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -43,9 +44,36 @@ StatusOr Executable::ExecuteOnStream(
   return result;
 }
 
+static ShapeTree MakeMaybeOwningDeviceMemoryTree(
+    const ShapedBuffer& shaped_buffer) {
+  ShapeTree result(shaped_buffer.on_device_shape());
+  auto in_it = shaped_buffer.buffers().begin();
+  auto out_it = result.begin();
+  for (; in_it != shaped_buffer.buffers().end(); ++in_it, ++out_it) {
+    DCHECK(out_it != result.end());
+    out_it->second = MaybeOwningDeviceMemory(in_it->second);
+  }
+  return result;
+}
+
+StatusOr Executable::ExecuteAsyncOnStream(
+    const ServiceExecutableRunOptions* run_options,
+    absl::Span arguments,
+    HloExecutionProfile* hlo_execution_profile) {
+  std::vector> args(arguments.size());
+  auto out_it = args.begin();
+  for (const ShapedBuffer* arg : arguments) {
+    *out_it++ = MakeMaybeOwningDeviceMemoryTree(*arg);
+  }
+  TF_ASSIGN_OR_RETURN(ExecutionOutput out,
+                      ExecuteAsyncOnStream(run_options, std::move(args),
+                                           hlo_execution_profile));
+  return out.ConsumeResult();
+}
+
 StatusOr Executable::ExecuteOnStream(
     const ServiceExecutableRunOptions* run_options,
-    std::vector> arguments,
+    std::vector> arguments,
     HloExecutionProfile* hlo_execution_profile) {
   StatusOr result = ExecuteAsyncOnStream(
       run_options, std::move(arguments), hlo_execution_profile);
@@ -55,14 +83,6 @@ StatusOr Executable::ExecuteOnStream(
   return result;
 }
 
-StatusOr Executable::ExecuteAsyncOnStream(
-    const ServiceExecutableRunOptions* /*run_options*/,
-    std::vector> /*arguments*/,
-    HloExecutionProfile* /*hlo_execution_profile*/) {
-  return Unimplemented(
-      "MaybeOwningDeviceMemory version of overload is not implemented ");
-}
-
 StatusOr> Executable::ExecuteOnStreams(
     absl::Span run_options,
     absl::Span> arguments) {
diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h
index 971dab95bfd..496599e7aaf 100644
--- a/tensorflow/compiler/xla/service/executable.h
+++ b/tensorflow/compiler/xla/service/executable.h
@@ -160,22 +160,22 @@ class Executable {
   // If the hlo_execution_profile is provided as non-nullptr, profiling will be
   // enabled. Note that profiling is tricky to use correctly, as the profiling
   // objects (when they exist) must out-live the task.
-  virtual StatusOr ExecuteAsyncOnStream(
+  StatusOr ExecuteAsyncOnStream(
       const ServiceExecutableRunOptions* run_options,
       absl::Span arguments,
-      HloExecutionProfile* hlo_execution_profile) = 0;
+      HloExecutionProfile* hlo_execution_profile);
 
   // Same as ExecuteAsyncOnStream(), but blocks waiting for the computation to
   // complete.
   StatusOr ExecuteOnStream(
       const ServiceExecutableRunOptions* run_options,
-      std::vector> arguments,
+      std::vector> arguments,
       HloExecutionProfile* hlo_execution_profile);
 
   virtual StatusOr ExecuteAsyncOnStream(
       const ServiceExecutableRunOptions* run_options,
-      std::vector> arguments,
-      HloExecutionProfile* hlo_execution_profile);
+      std::vector> arguments,
+      HloExecutionProfile* hlo_execution_profile) = 0;
 
   // Same as ExecuteOnStream(), but runs this executable on multiple
   // streams. arguments[i] contains the arguments to the execution on
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index 99bc0f7fee0..93af1cd995e 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -299,11 +299,14 @@ GpuExecutable::ResolveConstantGlobals(se::Stream* stream) {
   return &module_globals_.emplace(executor, std::move(globals)).first->second;
 }
 
-StatusOr GpuExecutable::Execute(
+StatusOr GpuExecutable::ExecuteAsyncOnStream(
     const ServiceExecutableRunOptions* run_options,
-    absl::Span arguments,
-    HloExecutionProfile* hlo_execution_profile, bool block_host_until_done) {
-  se::DeviceMemoryAllocator* memory_allocator = run_options->allocator();
+    std::vector> arguments,
+    HloExecutionProfile* hlo_execution_profile) {
+  se::DeviceMemoryAllocator* const memory_allocator = run_options->allocator();
+  // Force synchronous execution if the allocator requires it.
+  const bool block_host_until_done =
+      !memory_allocator->AllowsAsynchronousDeallocation();
 
   if (GetRootValueSet().IsAmbiguous()) {
     return Unimplemented("Points-to set of root instruction is ambiguous");
@@ -334,7 +337,9 @@ StatusOr GpuExecutable::Execute(
       if (allocation.is_entry_computation_parameter()) {
         auto param_no = allocation.parameter_number();
         se::DeviceMemoryBase buffer =
-            arguments[param_no]->buffer(allocation.param_shape_index());
+            arguments[param_no]
+                .element(allocation.param_shape_index())
+                .AsDeviceMemoryBase();
 
         // All top-level buffers and sub-buffers must have an explicit, non-null
         // pointer, except for zero-sized buffers, which may be null.
@@ -423,19 +428,17 @@ StatusOr GpuExecutable::Execute(
       }));
   TF_RETURN_IF_ERROR(buffer_allocations->TearDown(buffers_in_result));
 
-  return std::move(shaped_buffer);
-}
-
-StatusOr GpuExecutable::ExecuteAsyncOnStream(
-    const ServiceExecutableRunOptions* run_options,
-    absl::Span arguments,
-    HloExecutionProfile* hlo_execution_profile) {
-  se::DeviceMemoryAllocator* memory_allocator = run_options->allocator();
-  // Force synchronous execution if the allocator requires it.
-  bool block_host_until_done =
-      !memory_allocator->AllowsAsynchronousDeallocation();
-  return Execute(run_options, arguments, hlo_execution_profile,
-                 block_host_until_done);
+  std::vector buffers_to_free;
+  for (ShapeTree& argument : arguments) {
+    for (std::pair& buffer : argument) {
+      auto maybe_owning_buffer = buffer.second.Release();
+      if (maybe_owning_buffer) {
+        buffers_to_free.push_back(std::move(*maybe_owning_buffer));
+      }
+    }
+  }
+  return ExecutionOutput(std::move(shaped_buffer), std::move(buffers_to_free),
+                         {}, {});
 }
 
 const InstructionValueSet& GpuExecutable::GetRootValueSet() const {
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
index 66f86d768be..51e86a9f8ee 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
@@ -82,9 +82,9 @@ class GpuExecutable : public Executable {
 
   // ExecuteAsyncOnStream will fail if the compute capability of the stream
   // doesn't match the compute capability passed to this object's constructor.
-  StatusOr ExecuteAsyncOnStream(
+  StatusOr ExecuteAsyncOnStream(
       const ServiceExecutableRunOptions* run_options,
-      absl::Span arguments,
+      std::vector> arguments,
       HloExecutionProfile* hlo_execution_profile) override;
 
   std::shared_ptr GetBufferAssignment() const {
@@ -92,11 +92,6 @@ class GpuExecutable : public Executable {
   }
 
  private:
-  StatusOr Execute(
-      const ServiceExecutableRunOptions* run_options,
-      absl::Span arguments,
-      HloExecutionProfile* hlo_execution_profile, bool block_host_until_done);
-
   // If `block_host_until_done` is false, execution will not block the host
   // until the kernels have completed. This is used as an optimization for
   // clients, such as Tensorflow, that use a single stream of execution for
diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc
index 1c5b166a801..3e82e3271bb 100644
--- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc
+++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc
@@ -151,7 +151,8 @@ absl::optional HloInputOutputAliasConfig::GetAliasedOutput(
 absl::optional
 HloInputOutputAliasConfig::GetAliasedParameter(
     const ShapeIndex& output_index) const {
-  CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index));
+  CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index))
+      << ToString() << " " << alias_.shape().ToString() << " " << output_index;
   return alias_.element(output_index);
 }
 
diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD
index 3073c68c975..552c8eb1ae5 100644
--- a/tensorflow/compiler/xla/service/interpreter/BUILD
+++ b/tensorflow/compiler/xla/service/interpreter/BUILD
@@ -89,10 +89,14 @@ cc_library(
         "//tensorflow/compiler/xla/service:hlo_evaluator",
         "//tensorflow/compiler/xla/service:hlo_execution_profile",
         "//tensorflow/compiler/xla/service:hlo_module_config",
+        "//tensorflow/compiler/xla/service:maybe_owning_device_memory",
         "//tensorflow/compiler/xla/service:shaped_buffer",
         "//tensorflow/compiler/xla/service:transfer_manager",
         "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor_no_cuda",
+        "//tensorflow/core/platform:macros",
+        "//tensorflow/core/platform:mutex",
+        "//tensorflow/core/platform:types",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/types:span",
     ],
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc
index 0dab86d986c..f82a439fdb0 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.cc
+++ b/tensorflow/compiler/xla/service/interpreter/executable.cc
@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/interpreter/executor.h"
+#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
 #include "tensorflow/compiler/xla/service/transfer_manager.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -39,24 +40,39 @@ namespace interpreter {
 InterpreterExecutable::InterpreterExecutable(
     std::unique_ptr hlo_module,
     std::unique_ptr evaluator)
-    : Executable(std::move(hlo_module), /*hlo_profile_printer=*/nullptr,
+    : Executable(std::move(hlo_module), /*hlo_profile_printer_data=*/nullptr,
                  /*hlo_profile_index_map=*/nullptr),
       evaluator_(std::move(evaluator)) {}
 
 InterpreterExecutable::~InterpreterExecutable() {}
 
-StatusOr InterpreterExecutable::ExecuteAsyncOnStream(
+StatusOr InterpreterExecutable::ExecuteAsyncOnStream(
     const ServiceExecutableRunOptions* run_options,
-    absl::Span arguments,
+    std::vector> arguments,
     HloExecutionProfile* hlo_execution_profile) {
   se::Stream* stream = run_options->stream();
   se::StreamExecutor* executor = stream->parent();
   const se::Platform* platform = executor->platform();
 
+  // Convert the ShapeTree to a ShapedBuffer. We do this so we can call
+  // TransferManager methods below.
+  std::vector argument_buffers;
+  argument_buffers.reserve(arguments.size());
+  for (const ShapeTree& arg : arguments) {
+    argument_buffers.push_back(ShapedBuffer(arg.shape(), arg.shape(),
+                                            /*platform=*/nullptr,
+                                            /*device_ordinal=*/0));
+    auto in_it = arg.begin();
+    auto out_it = argument_buffers.back().buffers().begin();
+    for (; in_it != arg.end(); ++in_it, ++out_it) {
+      out_it->second = in_it->second.AsDeviceMemoryBase();
+    }
+  }
+
   VLOG(1) << "Execute " << module().name();
   if (VLOG_IS_ON(2)) {
-    for (const auto& a : arguments) {
-      VLOG(2) << "-- argument " << *a;
+    for (const auto& a : argument_buffers) {
+      VLOG(2) << "-- argument " << a;
     }
   }
 
@@ -71,7 +87,7 @@ StatusOr InterpreterExecutable::ExecuteAsyncOnStream(
   // Check that the args have the right shape.
   for (int64 i = 0; i < computation->num_parameters(); ++i) {
     const auto& expected_shape = computation->parameter_instruction(i)->shape();
-    const auto& actual_shape = arguments[i]->on_device_shape();
+    const auto& actual_shape = argument_buffers[i].on_device_shape();
     if (!Shape::Equal().MinorToMajorOnlyInLayout()(expected_shape,
                                                    actual_shape)) {
       return InvalidArgument(
@@ -90,7 +106,7 @@ StatusOr InterpreterExecutable::ExecuteAsyncOnStream(
   for (int64 p = 0; p < computation->num_parameters(); ++p) {
     TF_ASSIGN_OR_RETURN(Literal arg_literal,
                         transfer_manager->TransferLiteralFromDevice(
-                            run_options->stream(), *arguments[p]));
+                            run_options->stream(), argument_buffers[p]));
     arg_literals.push_back(std::move(arg_literal));
   }
 
@@ -119,7 +135,16 @@ StatusOr InterpreterExecutable::ExecuteAsyncOnStream(
     profile->set_compute_time_ns(std::max(nanoseconds, 1.0));
   }
 
-  return std::move(result);
+  std::vector buffers_to_free;
+  for (ShapeTree& argument : arguments) {
+    for (std::pair& buffer : argument) {
+      auto maybe_owning_buffer = buffer.second.Release();
+      if (maybe_owning_buffer) {
+        buffers_to_free.push_back(std::move(*maybe_owning_buffer));
+      }
+    }
+  }
+  return ExecutionOutput(std::move(result), std::move(buffers_to_free), {}, {});
 }
 
 /*static*/ int64 InterpreterExecutable::ShapeSizeBytes(const Shape& shape) {
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h
index ba010de76bd..1bea6773fdd 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.h
+++ b/tensorflow/compiler/xla/service/interpreter/executable.h
@@ -46,9 +46,9 @@ class InterpreterExecutable : public Executable {
                         std::unique_ptr evaluator);
   ~InterpreterExecutable() override;
 
-  StatusOr ExecuteAsyncOnStream(
+  StatusOr ExecuteAsyncOnStream(
       const ServiceExecutableRunOptions* run_options,
-      absl::Span arguments,
+      std::vector> arguments,
       HloExecutionProfile* hlo_execution_profile) override
       LOCKS_EXCLUDED(evaluator_lock_);
 
diff --git a/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc b/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc
index 5fe5fea71ac..c4bf48bcc00 100644
--- a/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc
+++ b/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc
@@ -17,7 +17,8 @@ limitations under the License.
 #include "absl/types/variant.h"
 namespace xla {
 
-tensorflow::se::DeviceMemoryBase MaybeOwningDeviceMemory::AsDeviceMemoryBase() {
+tensorflow::se::DeviceMemoryBase MaybeOwningDeviceMemory::AsDeviceMemoryBase()
+    const {
   if (HasOwnership()) {
     return *absl::get(mem_);
   } else {
diff --git a/tensorflow/compiler/xla/service/maybe_owning_device_memory.h b/tensorflow/compiler/xla/service/maybe_owning_device_memory.h
index 8edd64cf681..7d23d178130 100644
--- a/tensorflow/compiler/xla/service/maybe_owning_device_memory.h
+++ b/tensorflow/compiler/xla/service/maybe_owning_device_memory.h
@@ -49,7 +49,7 @@ class MaybeOwningDeviceMemory {
 
   // Fetches the underlying DeviceMemoryBase from a MaybeOwningDeviceMemory. The
   // caller of this function is *not* responsible for freeing the memory.
-  tensorflow::se::DeviceMemoryBase AsDeviceMemoryBase();
+  tensorflow::se::DeviceMemoryBase AsDeviceMemoryBase() const;
 
   // Release the tensorflow::se::OwningDeviceMemory without freeing it, and
   // moves the ownership of the memory buffer from the object to the caller.

From 260f7e04e51613d1c9a530d4c7d4a3dc1ca0a746 Mon Sep 17 00:00:00 2001
From: Eugene Brevdo 
Date: Tue, 3 Dec 2019 16:57:09 -0800
Subject: [PATCH 262/279] External workspaces can use pybind_extension - extend
 visibility for pybind libs.

This is required for projects that wish to build pybind extensions and have
tensorflow as a sub-workspace.  There is no other way to extend visibility to outer
workspaces in bazel.

PiperOrigin-RevId: 283652336
Change-Id: I749ef20c4f2da382e8c26d7e6679ab68e2cfb5af
---
 tensorflow/python/BUILD   | 48 +++++++++++++++++++++++----------------
 tensorflow/tensorflow.bzl |  4 ++++
 2 files changed, 33 insertions(+), 19 deletions(-)

diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 4518aeca3bf..e9e74e85ffa 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3,6 +3,21 @@
 # Public targets:
 #  ":platform" - Low-level and platform-specific Python code.
 
+load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "if_not_windows", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py", "tf_py_build_info_genrule", "tf_py_test")
+load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
+load("//tensorflow:tensorflow.bzl", "pybind_extension")
+load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
+load("//tensorflow:tensorflow.bzl", "tf_external_workspace_visible")
+load("//tensorflow/core/platform:build_config.bzl", "pyx_library", "tf_additional_all_protos", "tf_additional_cupti_test_flags", "tf_additional_lib_deps", "tf_proto_library", "tf_proto_library_py", "tf_protos_grappler")  # @unused
+load("//tensorflow/core/platform:build_config_root.bzl", "if_static", "tf_additional_plugin_deps")
+load("//tensorflow/python:build_defs.bzl", "tf_gen_op_wrapper_private_py")
+load(
+    "//third_party/ngraph:build_defs.bzl",
+    "if_ngraph",
+)
+
 visibility = [
     "//engedu/ml/tf_from_scratch:__pkg__",
     "//third_party/cloud_tpu/convergence_tools:__subpackages__",
@@ -19,20 +34,6 @@ visibility = [
     "//bazel_pip/tensorflow/lite/toco/python:__pkg__",
 ]
 
-load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "if_not_windows", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py", "tf_py_build_info_genrule", "tf_py_test")
-load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
-load("//tensorflow:tensorflow.bzl", "pybind_extension")
-load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
-load("//tensorflow:tensorflow.bzl", "cuda_py_test")
-load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
-load("//tensorflow/core/platform:build_config.bzl", "pyx_library", "tf_additional_all_protos", "tf_additional_cupti_test_flags", "tf_additional_lib_deps", "tf_proto_library", "tf_proto_library_py", "tf_protos_grappler")  # @unused
-load("//tensorflow/core/platform:build_config_root.bzl", "if_static", "tf_additional_plugin_deps")
-load("//tensorflow/python:build_defs.bzl", "tf_gen_op_wrapper_private_py")
-load(
-    "//third_party/ngraph:build_defs.bzl",
-    "if_ngraph",
-)
-
 package(
     default_visibility = visibility,
     licenses = ["notice"],  # Apache 2.0
@@ -415,9 +416,11 @@ cc_library(
     name = "ndarray_tensor_bridge",
     srcs = ["lib/core/ndarray_tensor_bridge.cc"],
     hdrs = ["lib/core/ndarray_tensor_bridge.h"],
-    visibility = visibility + [
-        "//learning/deepmind/courier:__subpackages__",
-    ],
+    visibility = tf_external_workspace_visible(
+        visibility + [
+            "//learning/deepmind/courier:__subpackages__",
+        ],
+    ),
     deps = [
         ":bfloat16_lib",
         ":numpy_lib",
@@ -443,6 +446,7 @@ cc_library(
     name = "pybind11_absl",
     hdrs = ["lib/core/pybind11_absl.h"],
     features = ["-parse_headers"],
+    visibility = tf_external_workspace_visible(visibility),
     deps = [
         "//tensorflow/core/platform:stringpiece",
         "@pybind11",
@@ -453,6 +457,7 @@ cc_library(
     name = "pybind11_lib",
     hdrs = ["lib/core/pybind11_lib.h"],
     features = ["-parse_headers"],
+    visibility = tf_external_workspace_visible(visibility),
     deps = [
         "@pybind11",
     ],
@@ -465,6 +470,7 @@ cc_library(
         "//tensorflow/c:headers",
     ],
     features = ["-parse_headers"],
+    visibility = tf_external_workspace_visible(visibility),
     deps = [
         ":py_exception_registry",
         "//tensorflow/c:tf_status_headers",
@@ -479,6 +485,7 @@ cc_library(
     name = "pybind11_proto",
     hdrs = ["lib/core/pybind11_proto.h"],
     features = ["-parse_headers"],
+    visibility = tf_external_workspace_visible(visibility),
     deps = [
         "@com_google_absl//absl/strings",
         "@pybind11",
@@ -780,9 +787,9 @@ cc_library(
     name = "ndarray_tensor",
     srcs = ["lib/core/ndarray_tensor.cc"],
     hdrs = ["lib/core/ndarray_tensor.h"],
-    visibility = visibility + [
+    visibility = tf_external_workspace_visible(visibility + [
         "//learning/deepmind/courier:__subpackages__",
-    ],
+    ]),
     deps = [
         ":bfloat16_lib",
         ":ndarray_tensor_bridge",
@@ -5598,17 +5605,20 @@ filegroup(
 cc_import(
     name = "_pywrap_tensorflow_internal_linux",
     shared_library = "//tensorflow/python:lib_pywrap_tensorflow_internal.so",
+    visibility = tf_external_workspace_visible(visibility),
 )
 
 cc_import(
     name = "_pywrap_tensorflow_internal_macos",
     shared_library = "//tensorflow/python:lib_pywrap_tensorflow_internal.dylib",
+    visibility = tf_external_workspace_visible(visibility),
 )
 
 cc_import(
     name = "_pywrap_tensorflow_internal_windows",
     interface_library = "//tensorflow/python:pywrap_tensorflow_import_lib_file",
     shared_library = "//tensorflow/python:_pywrap_tensorflow_internal.dll",
+    visibility = tf_external_workspace_visible(visibility),
 )
 
 # Rename the import library for _pywrap_tensorflow_internal.pyd to _pywrap_tensorflow_internal.lib
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 11598d9885e..5f9f2296c3c 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -2576,3 +2576,7 @@ def if_mlir(if_true, if_false = []):
 
 def tfcompile_extra_flags():
     return ""
+
+def tf_external_workspace_visible(visibility):
+    # External workspaces can see this target.
+    return ["//visibility:public"]

From 6a70aa6d438259cabd23c09808db4cf51a2e5377 Mon Sep 17 00:00:00 2001
From: Yanhui Liang 
Date: Tue, 3 Dec 2019 17:00:08 -0800
Subject: [PATCH 263/279] Add notes in doc string to highlight only `trainable`
 attribute can be modified after after the layer has been called once.

PiperOrigin-RevId: 283652789
Change-Id: I467dddc39c1ac90a3c981f69ac9947397587d337
---
 tensorflow/python/keras/layers/convolutional.py | 3 +++
 tensorflow/python/keras/layers/core.py          | 2 ++
 tensorflow/python/keras/layers/local.py         | 6 ++++++
 3 files changed, 11 insertions(+)

diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py
index 3a48d2339f3..fefbd1951e9 100644
--- a/tensorflow/python/keras/layers/convolutional.py
+++ b/tensorflow/python/keras/layers/convolutional.py
@@ -54,6 +54,9 @@ class Conv(Layer):
   a bias vector is created and added to the outputs. Finally, if
   `activation` is not `None`, it is applied to the outputs as well.
 
+  Note: layer attributes cannot be modified after the layer has been called
+  once (except the `trainable` attribute).
+
   Arguments:
     rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution.
     filters: Integer, the dimensionality of the output space (i.e. the number
diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py
index aad66429b75..600793aeb64 100644
--- a/tensorflow/python/keras/layers/core.py
+++ b/tensorflow/python/keras/layers/core.py
@@ -964,6 +964,8 @@ class Dense(Layer):
 
   Note: If the input to the layer has a rank greater than 2, then
   it is flattened prior to the initial dot product with `kernel`.
+  Besides, layer attributes cannot be modified after the layer has been called
+  once (except the `trainable` attribute).
 
   Example:
 
diff --git a/tensorflow/python/keras/layers/local.py b/tensorflow/python/keras/layers/local.py
index d94092023aa..ec7392e754e 100644
--- a/tensorflow/python/keras/layers/local.py
+++ b/tensorflow/python/keras/layers/local.py
@@ -41,6 +41,9 @@ class LocallyConnected1D(Layer):
   that is, a different set of filters is applied at each different patch
   of the input.
 
+  Note: layer attributes cannot be modified after the layer has been called
+  once (except the `trainable` attribute).
+
   Example:
   ```python
       # apply a unshared weight convolution 1d of length 3 to a sequence with
@@ -340,6 +343,9 @@ class LocallyConnected2D(Layer):
   that is, a different set of filters is applied at each
   different patch of the input.
 
+  Note: layer attributes cannot be modified after the layer has been called
+  once (except the `trainable` attribute).
+
   Examples:
   ```python
       # apply a 3x3 unshared weights convolution with 64 output filters on a

From ff26f1dc3e4a4937197a213f36753db0726225ed Mon Sep 17 00:00:00 2001
From: Tian Lin 
Date: Tue, 3 Dec 2019 17:17:30 -0800
Subject: [PATCH 264/279] Fix broken links.

PiperOrigin-RevId: 283655522
Change-Id: I99cfc2993719926f1f653ada7f33dc09f358e5d6
---
 tensorflow/lite/g3doc/guide/hosted_models.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tensorflow/lite/g3doc/guide/hosted_models.md b/tensorflow/lite/g3doc/guide/hosted_models.md
index ff3feb2e113..8bfdaf94538 100644
--- a/tensorflow/lite/g3doc/guide/hosted_models.md
+++ b/tensorflow/lite/g3doc/guide/hosted_models.md
@@ -13,12 +13,12 @@ models to find the optimal balance between size, performance, and accuracy.
 ## Image classification
 
 For more information about image classification, see
-Image classification.
+Image classification.
 
 ## Question and Answer
 
 For more information about text classification with Mobile BERT, see
-Question And Answer.
+Question And Answer.
 
 ### Quantized models
 

From 06b3cc42cea5c2cde01104818796ba72b7abff5a Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache 
Date: Tue, 3 Dec 2019 17:51:34 -0800
Subject: [PATCH 265/279] Refactor dependencies to expose Vector
 transformations as patterns - NFC

This CL refactors some of the MLIR vector dependencies to allow decoupling VectorOps, vector analysis, vector transformations and vector conversions from each other.
This makes the system more modular and allows extracting VectorToVector into VectorTransforms that do not depend on vector conversions.

This refactoring exhibited a bunch of cyclic library dependencies that have been cleaned up.

PiperOrigin-RevId: 283660308
Change-Id: I4bdd9cd994b8bfef20b8c25f43ea503fa5794512
---
 third_party/mlir/BUILD                        | 67 ++++++++++++---
 .../mlir/include/mlir/Analysis/LoopAnalysis.h |  9 +-
 .../ConvertVectorToLLVM.h}                    | 23 +-----
 .../VectorToLoops/ConvertVectorToLoops.h      | 36 ++++++++
 .../VectorOps/Utils.h}                        |  8 +-
 .../mlir/Dialect/VectorOps/VectorTransforms.h | 82 +++++++++++++++++++
 third_party/mlir/include/mlir/EDSC/Builders.h |  1 -
 .../mlir/include/mlir/EDSC/Intrinsics.h       |  1 -
 .../mlir/lib/Analysis/LoopAnalysis.cpp        | 21 ++---
 .../mlir/lib/Analysis/SliceAnalysis.cpp       |  1 -
 .../mlir/lib/Analysis/VectorAnalysis.cpp      |  2 +-
 .../mlir/lib/Conversion/CMakeLists.txt        |  3 +-
 .../Conversion/LinalgToLLVM/LinalgToLLVM.cpp  |  2 +-
 .../VectorConversions/CMakeLists.txt          | 18 ----
 .../Conversion/VectorToLLVM/CMakeLists.txt    | 15 ++++
 .../ConvertVectorToLLVM.cpp}                  |  2 +-
 .../Conversion/VectorToLoops/CMakeLists.txt   | 15 ++++
 .../ConvertVectorToLoops.cpp}                 |  5 +-
 .../mlir/lib/Dialect/VectorOps/CMakeLists.txt |  2 +
 .../VectorOps}/VectorToVector.cpp             |  7 +-
 third_party/mlir/lib/EDSC/Intrinsics.cpp      |  1 -
 .../lib/Transforms/MaterializeVectors.cpp     |  2 +-
 third_party/mlir/lib/Transforms/Vectorize.cpp | 14 +++-
 third_party/mlir/test/BUILD                   |  5 +-
 .../mlir/test/lib/Transforms/CMakeLists.txt   |  2 +-
 ...rs.cpp => TestVectorToLoopsConversion.cpp} | 15 ++--
 .../TestVectorToVectorConversion.cpp          |  2 +-
 .../lib/Transforms/TestVectorizationUtils.cpp |  2 +-
 .../mlir/tools/mlir-opt/CMakeLists.txt        |  5 +-
 29 files changed, 266 insertions(+), 102 deletions(-)
 rename third_party/mlir/include/mlir/Conversion/{VectorConversions/VectorConversions.h => VectorToLLVM/ConvertVectorToLLVM.h} (54%)
 create mode 100644 third_party/mlir/include/mlir/Conversion/VectorToLoops/ConvertVectorToLoops.h
 rename third_party/mlir/include/mlir/{Analysis/VectorAnalysis.h => Dialect/VectorOps/Utils.h} (96%)
 create mode 100644 third_party/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h
 delete mode 100644 third_party/mlir/lib/Conversion/VectorConversions/CMakeLists.txt
 create mode 100644 third_party/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
 rename third_party/mlir/lib/Conversion/{VectorConversions/VectorToLLVM.cpp => VectorToLLVM/ConvertVectorToLLVM.cpp} (99%)
 create mode 100644 third_party/mlir/lib/Conversion/VectorToLoops/CMakeLists.txt
 rename third_party/mlir/lib/Conversion/{VectorConversions/VectorToLoops.cpp => VectorToLoops/ConvertVectorToLoops.cpp} (99%)
 rename third_party/mlir/lib/{Conversion/VectorConversions => Dialect/VectorOps}/VectorToVector.cpp (98%)
 rename third_party/mlir/test/lib/Transforms/{TestLowerVectorTransfers.cpp => TestVectorToLoopsConversion.cpp} (71%)

diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD
index a522065d72a..26e03c46df9 100644
--- a/third_party/mlir/BUILD
+++ b/third_party/mlir/BUILD
@@ -196,13 +196,11 @@ cc_library(
     includes = ["include"],
     deps = [
         ":AffineOps",
-        ":Analysis",
         ":IR",
         ":LoopOps",
         ":StandardOps",
         ":Support",
         ":TransformUtils",
-        ":VectorOps",
         "@llvm//:support",
     ],
 )
@@ -475,15 +473,20 @@ cc_library(
     name = "VectorOps",
     srcs = [
         "lib/Dialect/VectorOps/VectorOps.cpp",
+        "lib/Dialect/VectorOps/VectorToVector.cpp",
     ],
     hdrs = [
+        "include/mlir/Dialect/VectorOps/Utils.h",
         "include/mlir/Dialect/VectorOps/VectorOps.h",
+        "include/mlir/Dialect/VectorOps/VectorTransforms.h",
     ],
     includes = ["include"],
     deps = [
+        ":EDSC",
         ":IR",
         ":Support",
         ":VectorOpsIncGen",
+        ":VectorTransformPatterns",
         "@llvm//:support",
     ],
 )
@@ -1302,6 +1305,7 @@ cc_library(
         ":StandardOps",
         ":Support",
         ":TransformUtils",
+        ":VectorAnalysis",
         ":VectorOps",
         "@llvm//:support",
     ],
@@ -1456,7 +1460,6 @@ cc_library(
         "lib/Analysis/TestMemRefDependenceCheck.cpp",
         "lib/Analysis/TestParallelismDetection.cpp",
         "lib/Analysis/Utils.cpp",
-        "lib/Analysis/VectorAnalysis.cpp",
         "lib/Analysis/Verifier.cpp",
     ],
     hdrs = [
@@ -1471,7 +1474,6 @@ cc_library(
         "include/mlir/Analysis/Passes.h",
         "include/mlir/Analysis/SliceAnalysis.h",
         "include/mlir/Analysis/Utils.h",
-        "include/mlir/Analysis/VectorAnalysis.h",
         "include/mlir/Analysis/Verifier.h",
     ],
     includes = ["include"],
@@ -1484,6 +1486,23 @@ cc_library(
         ":Pass",
         ":StandardOps",
         ":Support",
+        "@llvm//:support",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "VectorAnalysis",
+    srcs = [
+        "lib/Analysis/VectorAnalysis.cpp",
+    ],
+    includes = ["include"],
+    deps = [
+        ":AffineOps",
+        ":Analysis",
+        ":IR",
+        ":StandardOps",
+        ":Support",
         ":VectorOps",
         "@llvm//:support",
     ],
@@ -1669,7 +1688,8 @@ cc_library(
         ":StandardToSPIRVConversions",
         ":Support",
         ":Transforms",
-        ":VectorConversions",
+        ":VectorToLLVM",
+        ":VectorToLoops",
         ":ViewOpGraph",
         ":ViewRegionGraph",
         "@llvm//:support",
@@ -2202,7 +2222,7 @@ cc_library(
         ":StandardOps",
         ":Support",
         ":Transforms",
-        ":VectorConversions",
+        ":VectorToLLVM",
         "@llvm//:core",
         "@llvm//:support",
     ],
@@ -2372,18 +2392,40 @@ gentbl(
 )
 
 cc_library(
-    name = "VectorConversions",
+    name = "VectorToLLVM",
     srcs = [
-        "lib/Conversion/VectorConversions/VectorToLLVM.cpp",
-        "lib/Conversion/VectorConversions/VectorToLoops.cpp",
-        "lib/Conversion/VectorConversions/VectorToVector.cpp",  # TODO(transforms?)
+        "lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp",
     ],
     hdrs = [
-        "include/mlir/Conversion/VectorConversions/VectorConversions.h",
+        "include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h",
+    ],
+    includes = ["include"],
+    deps = [
+        ":EDSC",
+        ":IR",
+        ":LLVMDialect",
+        ":LLVMTransforms",
+        ":Pass",
+        ":StandardOps",
+        ":Support",
+        ":Transforms",
+        ":VectorOps",
+        "@llvm//:core",
+        "@llvm//:support",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "VectorToLoops",
+    srcs = [
+        "lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp",
+    ],
+    hdrs = [
+        "include/mlir/Conversion/VectorToLoops/ConvertVectorToLoops.h",
     ],
     includes = ["include"],
     deps = [
-        ":Analysis",
         ":EDSC",
         ":IR",
         ":LLVMDialect",
@@ -2393,7 +2435,6 @@ cc_library(
         ":Support",
         ":Transforms",
         ":VectorOps",
-        ":VectorTransformPatterns",
         "@llvm//:core",
         "@llvm//:support",
     ],
diff --git a/third_party/mlir/include/mlir/Analysis/LoopAnalysis.h b/third_party/mlir/include/mlir/Analysis/LoopAnalysis.h
index 7763a2bd262..8832c1469bc 100644
--- a/third_party/mlir/include/mlir/Analysis/LoopAnalysis.h
+++ b/third_party/mlir/include/mlir/Analysis/LoopAnalysis.h
@@ -31,8 +31,9 @@ namespace mlir {
 class AffineExpr;
 class AffineForOp;
 class AffineMap;
-class Operation;
 class MemRefType;
+class NestedPattern;
+class Operation;
 class Value;
 
 /// Returns the trip count of the loop as an affine map with its corresponding
@@ -91,14 +92,16 @@ using VectorizableLoopFun = std::function;
 ///   1. no conditionals are nested under the loop;
 ///   2. all nested load/stores are to scalar MemRefs.
 /// TODO(ntv): relax the no-conditionals restriction
-bool isVectorizableLoopBody(AffineForOp loop);
+bool isVectorizableLoopBody(AffineForOp loop,
+                            NestedPattern &vectorTransferMatcher);
 
 /// Checks whether the loop is structurally vectorizable and that all the LoadOp
 /// and StoreOp matched have access indexing functions that are are either:
 ///   1. invariant along the loop induction variable created by 'loop';
 ///   2. varying along at most one memory dimension. If such a unique dimension
 ///      is found, it is written into `memRefDim`.
-bool isVectorizableLoopBody(AffineForOp loop, int *memRefDim);
+bool isVectorizableLoopBody(AffineForOp loop, int *memRefDim,
+                            NestedPattern &vectorTransferMatcher);
 
 /// Checks where SSA dominance would be violated if a for op's body
 /// operations are shifted by the specified shifts. This method checks if a
diff --git a/third_party/mlir/include/mlir/Conversion/VectorConversions/VectorConversions.h b/third_party/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
similarity index 54%
rename from third_party/mlir/include/mlir/Conversion/VectorConversions/VectorConversions.h
rename to third_party/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index 56862ca0dad..a87e1c658a6 100644
--- a/third_party/mlir/include/mlir/Conversion/VectorConversions/VectorConversions.h
+++ b/third_party/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -1,4 +1,4 @@
-//===- VectorConversions.h - Utils to convert from the vector dialect -----===//
+//===- ConvertVectorToLLVM.h - Utils to convert from the vector dialect ---===//
 //
 // Copyright 2019 The MLIR Authors.
 //
@@ -14,31 +14,16 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 // =============================================================================
-#ifndef MLIR_CONVERSION_VECTORCONVERSIONS_VECTORCONVERSIONS_H_
-#define MLIR_CONVERSION_VECTORCONVERSIONS_VECTORCONVERSIONS_H_
+#ifndef MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVM_H_
+#define MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVM_H_
 
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
 class LLVMTypeConverter;
-class MLIRContext;
 class ModuleOp;
 template  class OpPassBase;
 
-/// Collect a set of patterns to convert from the Vector dialect to affine loops
-/// surrounding ops in different dialects (vector, std etc).
-/// This is the general place where we want to implement Vector -> Vector and
-/// Vector -> Std legalizations.
-void populateVectorToAffineLoopsConversionPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns);
-
-/// Collect a set of patterns to convert from the Vector dialect to itself.
-/// Should be merged with populateVectorToAffineLoopsConversionPatterns.
-void populateVectorToVectorConversionPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns,
-    ArrayRef coarseVectorShape = {},
-    ArrayRef fineVectorShape = {});
-
 /// Collect a set of patterns to convert from the Vector dialect to LLVM.
 void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                             OwningRewritePatternList &patterns);
@@ -48,4 +33,4 @@ OpPassBase *createLowerVectorToLLVMPass();
 
 } // namespace mlir
 
-#endif // MLIR_CONVERSION_VECTORCONVERSIONS_VECTORCONVERSIONS_H_
+#endif // MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVM_H_
diff --git a/third_party/mlir/include/mlir/Conversion/VectorToLoops/ConvertVectorToLoops.h b/third_party/mlir/include/mlir/Conversion/VectorToLoops/ConvertVectorToLoops.h
new file mode 100644
index 00000000000..198eaceda41
--- /dev/null
+++ b/third_party/mlir/include/mlir/Conversion/VectorToLoops/ConvertVectorToLoops.h
@@ -0,0 +1,36 @@
+//===- ConvertVectorToLoops.h - Utils to convert from the vector dialect --===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLOOPS_H_
+#define MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLOOPS_H_
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+class MLIRContext;
+class ModuleOp;
+template  class OpPassBase;
+
+/// Collect a set of patterns to convert from the Vector dialect to loops + std.
+void populateVectorToAffineLoopsConversionPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns);
+
+/// Create a pass to convert vector operations to affine loops + std dialect.
+OpPassBase *createLowerVectorToLoopsPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLOOPS_H_
diff --git a/third_party/mlir/include/mlir/Analysis/VectorAnalysis.h b/third_party/mlir/include/mlir/Dialect/VectorOps/Utils.h
similarity index 96%
rename from third_party/mlir/include/mlir/Analysis/VectorAnalysis.h
rename to third_party/mlir/include/mlir/Dialect/VectorOps/Utils.h
index 350bdfd8cce..2cff8795304 100644
--- a/third_party/mlir/include/mlir/Analysis/VectorAnalysis.h
+++ b/third_party/mlir/include/mlir/Dialect/VectorOps/Utils.h
@@ -1,4 +1,4 @@
-//===- VectorAnalysis.h - Analysis for Vectorization -------*- C++ -*-=======//
+//===- Utils.h - VectorOps Utils ----------------------------*- C++ -*-=======//
 //
 // Copyright 2019 The MLIR Authors.
 //
@@ -15,8 +15,8 @@
 // limitations under the License.
 // =============================================================================
 
-#ifndef MLIR_ANALYSIS_VECTORANALYSIS_H_
-#define MLIR_ANALYSIS_VECTORANALYSIS_H_
+#ifndef MLIR_DIALECT_VECTOROPS_UTILS_H_
+#define MLIR_DIALECT_VECTOROPS_UTILS_H_
 
 #include "mlir/Support/LLVM.h"
 
@@ -140,4 +140,4 @@ bool operatesOnSuperVectorsOf(Operation &op, VectorType subVectorType);
 } // end namespace matcher
 } // end namespace mlir
 
-#endif // MLIR_ANALYSIS_VECTORANALYSIS_H_
+#endif // MLIR_DIALECT_VECTOROPS_UTILS_H_
diff --git a/third_party/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h b/third_party/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h
new file mode 100644
index 00000000000..2c2e4e7c4fa
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h
@@ -0,0 +1,82 @@
+//===- VectorTransforms.h - Vector transformations as patterns --*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef DIALECT_VECTOROPS_VECTORTRANSFORMS_H_
+#define DIALECT_VECTOROPS_VECTORTRANSFORMS_H_
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+class MLIRContext;
+class OwningRewritePatternList;
+
+/// Collect a set of patterns to convert from the Vector dialect to itself.
+/// Should be merged with populateVectorToAffineLoopsConversionPatterns.
+void populateVectorToVectorConversionPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns,
+    ArrayRef coarseVectorShape = {},
+    ArrayRef fineVectorShape = {});
+
+////////////////////////////////////////////////////////////////////////////////
+// The following Declarative Rewrite Rule (DRR) helpers are used in rewrite
+// patterns. As such, they must not call into `rewriter.erase/replace` APIs and
+// it is the responsibility of the enclosing PatternRewriter to erase on
+// success.
+////////////////////////////////////////////////////////////////////////////////
+
+namespace vector {
+
+// Entry point for unrolling declarative pattern rewrites.
+// `op` is unrolled to the `targetShape` as follows, for each of its operands:
+//   1. the unrolled type `unrolledVectorType` and number of unrolled instances
+//   `numUnrolledInstances` are computed from the `targetShape`. For now it is
+//   assumed the unrolling factors divide the vector sizes.
+//   2. a fakeFork cast op is inserted that takes the operand and returns
+//   `numUnrolledInstances` results of type `unrolledVectorType`.
+//   3. the original op is cloned `numUnrolledInstances` times, once for each
+//   result of the fakeFork cast op.
+//   4. a fakeJoin cast op takes all these results and merges them into a single
+//   aggregate vector result whose size matches the original non-unrolled op
+//   operand types.
+//
+// Example:
+//
+//    opA(operand0, operand1)  // numUnrolledInstances = 3
+//
+//            operand0                   operand1
+//               |                          |
+//             fork                       fork
+//        <----------gather all fork ops --------->
+//              /|\                        /|\
+//          f00 f01 f02                f10 f11 f12
+//        <---------- clone op 3 times --------->
+//          opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
+//                 \            |            /
+//      <-------------------- join ------------------------->
+//
+// Other local patterns then kick in iteratively (including DCE) and compose
+// until all the fakeFork and fakeJoin ops are removed.
+//
+// This will be extended in the future to support more advanced use cases than
+// simple pointwise ops.
+Value *unrollSingleResultOpMatchingType(PatternRewriter &builder, Operation *op,
+                                        ArrayRef targetShape);
+
+} // namespace vector
+} // namespace mlir
+
+#endif // DIALECT_VECTOROPS_VECTORTRANSFORMS_H_
diff --git a/third_party/mlir/include/mlir/EDSC/Builders.h b/third_party/mlir/include/mlir/EDSC/Builders.h
index 1927ce60eab..5940f1c244f 100644
--- a/third_party/mlir/include/mlir/EDSC/Builders.h
+++ b/third_party/mlir/include/mlir/EDSC/Builders.h
@@ -26,7 +26,6 @@
 #include "mlir/Dialect/AffineOps/AffineOps.h"
 #include "mlir/Dialect/LoopOps/LoopOps.h"
 #include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/Dialect/VectorOps/VectorOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/Transforms/FoldUtils.h"
 
diff --git a/third_party/mlir/include/mlir/EDSC/Intrinsics.h b/third_party/mlir/include/mlir/EDSC/Intrinsics.h
index 6e1c49f66cc..468cc1c4240 100644
--- a/third_party/mlir/include/mlir/EDSC/Intrinsics.h
+++ b/third_party/mlir/include/mlir/EDSC/Intrinsics.h
@@ -215,7 +215,6 @@ using select = ValueBuilder;
 using std_load = ValueBuilder;
 using std_store = OperationBuilder;
 using subi = ValueBuilder;
-using vector_type_cast = ValueBuilder;
 using view = ValueBuilder;
 
 /// Branches into the mlir::Block* captured by BlockHandle `b` with `operands`.
diff --git a/third_party/mlir/lib/Analysis/LoopAnalysis.cpp b/third_party/mlir/lib/Analysis/LoopAnalysis.cpp
index f01e548a3df..b297a63cb62 100644
--- a/third_party/mlir/lib/Analysis/LoopAnalysis.cpp
+++ b/third_party/mlir/lib/Analysis/LoopAnalysis.cpp
@@ -25,7 +25,6 @@
 #include "mlir/Analysis/AffineStructures.h"
 #include "mlir/Analysis/NestedMatcher.h"
 #include "mlir/Dialect/AffineOps/AffineOps.h"
-#include "mlir/Dialect/VectorOps/VectorOps.h"
 #include "mlir/Support/MathExtras.h"
 
 #include "llvm/ADT/DenseSet.h"
@@ -273,15 +272,12 @@ static bool isVectorElement(LoadOrStoreOpPointer memoryOp) {
   return memRefType.getElementType().template isa();
 }
 
-static bool isVectorTransferReadOrWrite(Operation &op) {
-  return isa(op) || isa(op);
-}
-
 using VectorizableOpFun = std::function;
 
 static bool
 isVectorizableLoopBodyWithOpCond(AffineForOp loop,
-                                 VectorizableOpFun isVectorizableOp) {
+                                 VectorizableOpFun isVectorizableOp,
+                                 NestedPattern &vectorTransferMatcher) {
   auto *forOp = loop.getOperation();
 
   // No vectorization across conditionals for now.
@@ -303,9 +299,8 @@ isVectorizableLoopBodyWithOpCond(AffineForOp loop,
     return false;
   }
 
-  auto vectorTransfers = matcher::Op(isVectorTransferReadOrWrite);
   SmallVector vectorTransfersMatched;
-  vectorTransfers.match(forOp, &vectorTransfersMatched);
+  vectorTransferMatcher.match(forOp, &vectorTransfersMatched);
   if (!vectorTransfersMatched.empty()) {
     return false;
   }
@@ -331,18 +326,20 @@ isVectorizableLoopBodyWithOpCond(AffineForOp loop,
   return true;
 }
 
-bool mlir::isVectorizableLoopBody(AffineForOp loop, int *memRefDim) {
+bool mlir::isVectorizableLoopBody(AffineForOp loop, int *memRefDim,
+                                  NestedPattern &vectorTransferMatcher) {
   VectorizableOpFun fun([memRefDim](AffineForOp loop, Operation &op) {
     auto load = dyn_cast(op);
     auto store = dyn_cast(op);
     return load ? isContiguousAccess(loop.getInductionVar(), load, memRefDim)
                 : isContiguousAccess(loop.getInductionVar(), store, memRefDim);
   });
-  return isVectorizableLoopBodyWithOpCond(loop, fun);
+  return isVectorizableLoopBodyWithOpCond(loop, fun, vectorTransferMatcher);
 }
 
-bool mlir::isVectorizableLoopBody(AffineForOp loop) {
-  return isVectorizableLoopBodyWithOpCond(loop, nullptr);
+bool mlir::isVectorizableLoopBody(AffineForOp loop,
+                                  NestedPattern &vectorTransferMatcher) {
+  return isVectorizableLoopBodyWithOpCond(loop, nullptr, vectorTransferMatcher);
 }
 
 /// Checks whether SSA dominance would be violated if a for op's body
diff --git a/third_party/mlir/lib/Analysis/SliceAnalysis.cpp b/third_party/mlir/lib/Analysis/SliceAnalysis.cpp
index 718db0e76d2..700321ebb40 100644
--- a/third_party/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/third_party/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -20,7 +20,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/SliceAnalysis.h"
-#include "mlir/Analysis/VectorAnalysis.h"
 #include "mlir/Dialect/AffineOps/AffineOps.h"
 #include "mlir/Dialect/LoopOps/LoopOps.h"
 #include "mlir/IR/Function.h"
diff --git a/third_party/mlir/lib/Analysis/VectorAnalysis.cpp b/third_party/mlir/lib/Analysis/VectorAnalysis.cpp
index 666ee071c63..42d3f10b14c 100644
--- a/third_party/mlir/lib/Analysis/VectorAnalysis.cpp
+++ b/third_party/mlir/lib/Analysis/VectorAnalysis.cpp
@@ -15,11 +15,11 @@
 // limitations under the License.
 // =============================================================================
 
-#include "mlir/Analysis/VectorAnalysis.h"
 #include "mlir/Analysis/AffineAnalysis.h"
 #include "mlir/Analysis/LoopAnalysis.h"
 #include "mlir/Dialect/AffineOps/AffineOps.h"
 #include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Dialect/VectorOps/Utils.h"
 #include "mlir/Dialect/VectorOps/VectorOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/IntegerSet.h"
diff --git a/third_party/mlir/lib/Conversion/CMakeLists.txt b/third_party/mlir/lib/Conversion/CMakeLists.txt
index 6d370f714e2..c791d214d30 100644
--- a/third_party/mlir/lib/Conversion/CMakeLists.txt
+++ b/third_party/mlir/lib/Conversion/CMakeLists.txt
@@ -8,4 +8,5 @@ add_subdirectory(LoopsToGPU)
 add_subdirectory(LoopToStandard)
 add_subdirectory(StandardToLLVM)
 add_subdirectory(StandardToSPIRV)
-add_subdirectory(VectorConversions)
+add_subdirectory(VectorToLLVM)
+add_subdirectory(VectorToLoops)
diff --git a/third_party/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/third_party/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index ff516d7ef29..709dd3af7f0 100644
--- a/third_party/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/third_party/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -20,7 +20,7 @@
 #include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
-#include "mlir/Conversion/VectorConversions/VectorConversions.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
diff --git a/third_party/mlir/lib/Conversion/VectorConversions/CMakeLists.txt b/third_party/mlir/lib/Conversion/VectorConversions/CMakeLists.txt
deleted file mode 100644
index c8d699e4462..00000000000
--- a/third_party/mlir/lib/Conversion/VectorConversions/CMakeLists.txt
+++ /dev/null
@@ -1,18 +0,0 @@
-add_llvm_library(MLIRVectorConversions
-  VectorToLLVM.cpp
-  VectorToLoops.cpp
-  VectorToVector.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorConversions
-)
-set(LIBS
-  MLIRLLVMIR
-  MLIRTransforms
-  LLVMCore
-  LLVMSupport
-  )
-
-add_dependencies(MLIRVectorConversions ${LIBS})
-add_dependencies(MLIRVectorConversions MLIRVectorTransformPatternsIncGen)
-target_link_libraries(MLIRVectorConversions ${LIBS})
diff --git a/third_party/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/third_party/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
new file mode 100644
index 00000000000..2aaec68f6c4
--- /dev/null
+++ b/third_party/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_llvm_library(MLIRVectorToLLVM
+  ConvertVectorToLLVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToLLVM
+)
+set(LIBS
+  MLIRLLVMIR
+  MLIRTransforms
+  LLVMCore
+  LLVMSupport
+  )
+
+add_dependencies(MLIRVectorToLLVM ${LIBS})
+target_link_libraries(MLIRVectorToLLVM ${LIBS})
diff --git a/third_party/mlir/lib/Conversion/VectorConversions/VectorToLLVM.cpp b/third_party/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
similarity index 99%
rename from third_party/mlir/lib/Conversion/VectorConversions/VectorToLLVM.cpp
rename to third_party/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 5420ad05ae1..7221998ce25 100644
--- a/third_party/mlir/lib/Conversion/VectorConversions/VectorToLLVM.cpp
+++ b/third_party/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -17,7 +17,7 @@
 
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
-#include "mlir/Conversion/VectorConversions/VectorConversions.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/VectorOps/VectorOps.h"
 #include "mlir/IR/Attributes.h"
diff --git a/third_party/mlir/lib/Conversion/VectorToLoops/CMakeLists.txt b/third_party/mlir/lib/Conversion/VectorToLoops/CMakeLists.txt
new file mode 100644
index 00000000000..e213bc9bcce
--- /dev/null
+++ b/third_party/mlir/lib/Conversion/VectorToLoops/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_llvm_library(MLIRVectorToLoops
+  ConvertVectorToLoops.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToLoops
+)
+set(LIBS
+  MLIRLLVMIR
+  MLIRTransforms
+  LLVMCore
+  LLVMSupport
+  )
+
+add_dependencies(MLIRVectorToLoops ${LIBS})
+target_link_libraries(MLIRVectorToLoops ${LIBS})
diff --git a/third_party/mlir/lib/Conversion/VectorConversions/VectorToLoops.cpp b/third_party/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp
similarity index 99%
rename from third_party/mlir/lib/Conversion/VectorConversions/VectorToLoops.cpp
rename to third_party/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp
index 74479b9922d..43ad91ce878 100644
--- a/third_party/mlir/lib/Conversion/VectorConversions/VectorToLoops.cpp
+++ b/third_party/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp
@@ -21,8 +21,8 @@
 
 #include 
 
-#include "mlir/Conversion/VectorConversions/VectorConversions.h"
 #include "mlir/Dialect/VectorOps/VectorOps.h"
+#include "mlir/Conversion/VectorToLoops/ConvertVectorToLoops.h"
 #include "mlir/EDSC/Builders.h"
 #include "mlir/EDSC/Helpers.h"
 #include "mlir/IR/AffineExpr.h"
@@ -41,6 +41,8 @@ using vector::TransferWriteOp;
 
 namespace {
 
+using vector_type_cast = edsc::intrinsics::ValueBuilder;
+
 /// Implements lowering of TransferReadOp and TransferWriteOp to a
 /// proper abstraction for the hardware.
 ///
@@ -356,7 +358,6 @@ PatternMatchResult VectorTransferRewriter::matchAndRewrite(
 }
 } // namespace
 
-/// Populate the given list with patterns that convert from Vector to LLVM.
 void mlir::populateVectorToAffineLoopsConversionPatterns(
     MLIRContext *context, OwningRewritePatternList &patterns) {
   patterns.insert,
diff --git a/third_party/mlir/lib/Dialect/VectorOps/CMakeLists.txt b/third_party/mlir/lib/Dialect/VectorOps/CMakeLists.txt
index 590eeed6f41..754e62de14e 100644
--- a/third_party/mlir/lib/Dialect/VectorOps/CMakeLists.txt
+++ b/third_party/mlir/lib/Dialect/VectorOps/CMakeLists.txt
@@ -1,11 +1,13 @@
 add_llvm_library(MLIRVectorOps
   DialectRegistration.cpp
   VectorOps.cpp
+  VectorToVector.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/VectorOps
   )
 
 add_dependencies(MLIRVectorOps MLIRVectorOpsIncGen)
+add_dependencies(MLIRVectorOps MLIRVectorTransformPatternsIncGen)
 
 target_link_libraries(MLIRVectorOps MLIRIR)
diff --git a/third_party/mlir/lib/Conversion/VectorConversions/VectorToVector.cpp b/third_party/mlir/lib/Dialect/VectorOps/VectorToVector.cpp
similarity index 98%
rename from third_party/mlir/lib/Conversion/VectorConversions/VectorToVector.cpp
rename to third_party/mlir/lib/Dialect/VectorOps/VectorToVector.cpp
index 74687d449af..1e2e651189f 100644
--- a/third_party/mlir/lib/Conversion/VectorConversions/VectorToVector.cpp
+++ b/third_party/mlir/lib/Dialect/VectorOps/VectorToVector.cpp
@@ -21,10 +21,9 @@
 
 #include 
 
-#include "mlir/Analysis/VectorAnalysis.h"
-#include "mlir/Conversion/VectorConversions/VectorConversions.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Dialect/VectorOps/Utils.h"
 #include "mlir/Dialect/VectorOps/VectorOps.h"
+#include "mlir/Dialect/VectorOps/VectorTransforms.h"
 #include "mlir/EDSC/Builders.h"
 #include "mlir/EDSC/Helpers.h"
 #include "mlir/IR/AffineExpr.h"
@@ -198,7 +197,7 @@ static bool hasShape(Value *v, ArrayRef shape) {
 //
 // This will be extended in the future to support more advanced use cases than
 // simple pointwise ops.
-static Value *unrollSingleResultOpMatchingType(PatternRewriter &builder,
+Value * mlir::vector::unrollSingleResultOpMatchingType(PatternRewriter &builder,
                                                Operation *op,
                                                ArrayRef targetShape) {
   LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
diff --git a/third_party/mlir/lib/EDSC/Intrinsics.cpp b/third_party/mlir/lib/EDSC/Intrinsics.cpp
index f80726866fc..1b19f9aa0bf 100644
--- a/third_party/mlir/lib/EDSC/Intrinsics.cpp
+++ b/third_party/mlir/lib/EDSC/Intrinsics.cpp
@@ -16,7 +16,6 @@
 // =============================================================================
 
 #include "mlir/EDSC/Intrinsics.h"
-#include "mlir/Dialect/VectorOps/VectorOps.h"
 #include "mlir/EDSC/Builders.h"
 #include "mlir/IR/AffineExpr.h"
 
diff --git a/third_party/mlir/lib/Transforms/MaterializeVectors.cpp b/third_party/mlir/lib/Transforms/MaterializeVectors.cpp
index 874eac6e4e6..33f5927d88e 100644
--- a/third_party/mlir/lib/Transforms/MaterializeVectors.cpp
+++ b/third_party/mlir/lib/Transforms/MaterializeVectors.cpp
@@ -26,9 +26,9 @@
 #include "mlir/Analysis/NestedMatcher.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Analysis/Utils.h"
-#include "mlir/Analysis/VectorAnalysis.h"
 #include "mlir/Dialect/AffineOps/AffineOps.h"
 #include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Dialect/VectorOps/Utils.h"
 #include "mlir/Dialect/VectorOps/VectorOps.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
diff --git a/third_party/mlir/lib/Transforms/Vectorize.cpp b/third_party/mlir/lib/Transforms/Vectorize.cpp
index 2a0ce092f81..c1e0a9c0e13 100644
--- a/third_party/mlir/lib/Transforms/Vectorize.cpp
+++ b/third_party/mlir/lib/Transforms/Vectorize.cpp
@@ -24,9 +24,9 @@
 #include "mlir/Analysis/NestedMatcher.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Analysis/Utils.h"
-#include "mlir/Analysis/VectorAnalysis.h"
 #include "mlir/Dialect/AffineOps/AffineOps.h"
 #include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Dialect/VectorOps/Utils.h"
 #include "mlir/Dialect/VectorOps/VectorOps.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Builders.h"
@@ -589,6 +589,13 @@ makePatterns(const llvm::DenseSet ¶llelLoops, int vectorRank,
   }
 }
 
+static NestedPattern &vectorTransferPattern() {
+  static auto pattern = matcher::Op([](Operation &op) {
+    return isa(op) || isa(op);
+  });
+  return pattern;
+}
+
 namespace {
 
 /// Base state for the vectorize pass.
@@ -893,7 +900,8 @@ isVectorizableLoopPtrFactory(const llvm::DenseSet ¶llelLoops,
     if (parallelIt == parallelLoops.end())
       return false;
     int memRefDim = -1;
-    auto vectorizableBody = isVectorizableLoopBody(loop, &memRefDim);
+    auto vectorizableBody =
+        isVectorizableLoopBody(loop, &memRefDim, vectorTransferPattern());
     if (!vectorizableBody)
       return false;
     return memRefDim == -1 || fastestVaryingMemRefDimension == -1 ||
@@ -1172,7 +1180,7 @@ static LogicalResult vectorizeRootMatch(NestedMatch m,
   // vectorizable. If a pattern is not vectorizable anymore, we just skip it.
   // TODO(ntv): implement a non-greedy profitability analysis that keeps only
   // non-intersecting patterns.
-  if (!isVectorizableLoopBody(loop)) {
+  if (!isVectorizableLoopBody(loop, vectorTransferPattern())) {
     LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ loop is not vectorizable");
     return failure();
   }
diff --git a/third_party/mlir/test/BUILD b/third_party/mlir/test/BUILD
index 85680008d6d..25f7b8399eb 100644
--- a/third_party/mlir/test/BUILD
+++ b/third_party/mlir/test/BUILD
@@ -133,9 +133,9 @@ cc_library(
         "lib/Transforms/TestLoopFusion.cpp",
         "lib/Transforms/TestLoopMapping.cpp",
         "lib/Transforms/TestLoopParametricTiling.cpp",
-        "lib/Transforms/TestLowerVectorTransfers.cpp",
         "lib/Transforms/TestMemRefStrideCalculation.cpp",
         "lib/Transforms/TestOpaqueLoc.cpp",
+        "lib/Transforms/TestVectorToLoopsConversion.cpp",
         "lib/Transforms/TestVectorToVectorConversion.cpp",
         "lib/Transforms/TestVectorizationUtils.cpp",
     ],
@@ -155,8 +155,9 @@ cc_library(
         "@local_config_mlir//:Support",
         "@local_config_mlir//:TransformUtils",
         "@local_config_mlir//:Transforms",
-        "@local_config_mlir//:VectorConversions",
         "@local_config_mlir//:VectorOps",
+        "@local_config_mlir//:VectorToLLVM",
+        "@local_config_mlir//:VectorToLoops",
     ],
     alwayslink = 1,
 )
diff --git a/third_party/mlir/test/lib/Transforms/CMakeLists.txt b/third_party/mlir/test/lib/Transforms/CMakeLists.txt
index 2d482e5f1a5..8bc9c736187 100644
--- a/third_party/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/third_party/mlir/test/lib/Transforms/CMakeLists.txt
@@ -6,9 +6,9 @@ add_llvm_library(MLIRTestTransforms
   TestLinalgTransforms.cpp
   TestLoopMapping.cpp
   TestLoopParametricTiling.cpp
-  TestLowerVectorTransfers.cpp
   TestOpaqueLoc.cpp
   TestMemRefStrideCalculation.cpp
+  TestVectorToLoopsConversion.cpp
   TestVectorToVectorConversion.cpp
   TestVectorizationUtils.cpp
 
diff --git a/third_party/mlir/test/lib/Transforms/TestLowerVectorTransfers.cpp b/third_party/mlir/test/lib/Transforms/TestVectorToLoopsConversion.cpp
similarity index 71%
rename from third_party/mlir/test/lib/Transforms/TestLowerVectorTransfers.cpp
rename to third_party/mlir/test/lib/Transforms/TestVectorToLoopsConversion.cpp
index 8341777f6a4..e5f5f749bd0 100644
--- a/third_party/mlir/test/lib/Transforms/TestLowerVectorTransfers.cpp
+++ b/third_party/mlir/test/lib/Transforms/TestVectorToLoopsConversion.cpp
@@ -1,4 +1,4 @@
-//===- TestLowerVectorTransfers.cpp - Test VectorTransfers lowering -------===//
+//===- TestVectorToLoopsConversion.cpp - Test VectorTransfers lowering ----===//
 //
 // Copyright 2019 The MLIR Authors.
 //
@@ -17,7 +17,7 @@
 
 #include 
 
-#include "mlir/Conversion/VectorConversions/VectorConversions.h"
+#include "mlir/Conversion/VectorToLoops/ConvertVectorToLoops.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/Passes.h"
@@ -26,8 +26,8 @@ using namespace mlir;
 
 namespace {
 
-struct TestLowerVectorTransfersPass
-    : public FunctionPass {
+struct TestVectorToLoopsPass
+    : public FunctionPass {
   void runOnFunction() override {
     OwningRewritePatternList patterns;
     auto *context = &getContext();
@@ -38,7 +38,6 @@ struct TestLowerVectorTransfersPass
 
 } // end anonymous namespace
 
-static PassRegistration
-    pass("test-affine-lower-vector-transfers",
-         "Materializes vector transfer ops to a "
-         "proper abstraction for the hardware");
+static PassRegistration
+    pass("test-convert-vector-to-loops",
+         "Converts vector transfer ops to loops over scalars and vector casts");
diff --git a/third_party/mlir/test/lib/Transforms/TestVectorToVectorConversion.cpp b/third_party/mlir/test/lib/Transforms/TestVectorToVectorConversion.cpp
index 2550796ade2..9f9b8a554fe 100644
--- a/third_party/mlir/test/lib/Transforms/TestVectorToVectorConversion.cpp
+++ b/third_party/mlir/test/lib/Transforms/TestVectorToVectorConversion.cpp
@@ -18,7 +18,7 @@
 
 #include 
 
-#include "mlir/Conversion/VectorConversions/VectorConversions.h"
+#include "mlir/Dialect/VectorOps/VectorTransforms.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/Passes.h"
diff --git a/third_party/mlir/test/lib/Transforms/TestVectorizationUtils.cpp b/third_party/mlir/test/lib/Transforms/TestVectorizationUtils.cpp
index f0f1f6b0b23..7efc74f2304 100644
--- a/third_party/mlir/test/lib/Transforms/TestVectorizationUtils.cpp
+++ b/third_party/mlir/test/lib/Transforms/TestVectorizationUtils.cpp
@@ -22,8 +22,8 @@
 #include "mlir/Analysis/AffineAnalysis.h"
 #include "mlir/Analysis/NestedMatcher.h"
 #include "mlir/Analysis/SliceAnalysis.h"
-#include "mlir/Analysis/VectorAnalysis.h"
 #include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/VectorOps/Utils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/StandardTypes.h"
diff --git a/third_party/mlir/tools/mlir-opt/CMakeLists.txt b/third_party/mlir/tools/mlir-opt/CMakeLists.txt
index e38b43d59b8..b30d7e39ce8 100644
--- a/third_party/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/third_party/mlir/tools/mlir-opt/CMakeLists.txt
@@ -21,7 +21,7 @@ set(LIBS
   MLIRAffineToStandard
   MLIRLoopsToGPU
   MLIRLinalgToLLVM
-  
+
   MLIRLoopToStandard
   MLIREDSC
   MLIRFxpMathOps
@@ -51,7 +51,8 @@ set(LIBS
   MLIRTestTransforms
   MLIRSupport
   MLIRVectorOps
-  MLIRVectorConversions
+  MLIRVectorToLLVM
+  MLIRVectorToLoops
 )
 if(MLIR_CUDA_CONVERSIONS_ENABLED)
   list(APPEND LIBS

From ba102264d4eb4ed8814384e8d2de8767ef8e9693 Mon Sep 17 00:00:00 2001
From: Thomas O'Malley 
Date: Tue, 3 Dec 2019 18:48:28 -0800
Subject: [PATCH 266/279] Add error check if Lambda layer function doesn't use
 unique Variable names.

PiperOrigin-RevId: 283667330
Change-Id: I3d74b2876b48a42ebab495b3c19abc9a4ffd6203
---
 tensorflow/python/keras/layers/core.py      | 12 ++++++++++++
 tensorflow/python/keras/layers/core_test.py | 11 +++++++++++
 2 files changed, 23 insertions(+)

diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py
index 600793aeb64..ee44fcbd946 100644
--- a/tensorflow/python/keras/layers/core.py
+++ b/tensorflow/python/keras/layers/core.py
@@ -818,6 +818,8 @@ class Lambda(Layer):
     return nest.map_structure(_add_batch, output_shapes)
 
   def call(self, inputs, mask=None, training=None):
+    # Disallow two variables with the same name.
+    self._variables_added_in_call = set()
     arguments = self.arguments
     if self._fn_expects_mask_arg:
       arguments['mask'] = mask
@@ -828,8 +830,18 @@ class Lambda(Layer):
 
   def _variable_creator(self, next_creator, **kwargs):
     name = kwargs['name']
+
+    # Variable named "name" already created in this invocation of `call`.
+    if name in self._variables_added_in_call:
+      raise RuntimeError('`Variable`s in a `Lambda` layer must have unique '
+                         'names, found duplicate name: {}'.format(name))
+    self._variables_added_in_call.add(name)
+
+    # Reuse Variables across invocations of `call`.
     if name in self._variable_dict:
       return self._variable_dict[name]
+
+    # Variable was never created before.
     var = next_creator(**kwargs)
     self._variable_dict[name] = var
     if var.trainable:
diff --git a/tensorflow/python/keras/layers/core_test.py b/tensorflow/python/keras/layers/core_test.py
index aa7b42d0e95..05f89053fcb 100644
--- a/tensorflow/python/keras/layers/core_test.py
+++ b/tensorflow/python/keras/layers/core_test.py
@@ -236,6 +236,17 @@ class LambdaLayerTest(keras_parameterized.TestCase):
     self.assertLen(layer.trainable_weights, 1)
     self.assertEqual(layer.trainable_weights[0].name, 'lambda/multiplier:0')
 
+  def test_lambda_with_duplicate_variable_names(self):
+
+    def fn(x):
+      v1 = variables.Variable(2.)
+      v2 = variables.Variable(1.)
+      return x * v1 * v2
+
+    layer = keras.layers.Lambda(fn)
+    with self.assertRaisesRegexp(RuntimeError, 'must have unique names'):
+      layer(np.ones((10, 10), 'float32'))
+
   def test_lambda_with_training_arg(self):
 
     def fn(x, training=True):

From d219e9017148c4f8bc8f0d3db638bf7ae2309397 Mon Sep 17 00:00:00 2001
From: Brian Zhao 
Date: Tue, 3 Dec 2019 20:08:27 -0800
Subject: [PATCH 267/279] Adding test_benchmark build target to
 tf/core/platform/BUILD. This change is part of the refactoring described in
 Tensorflow Build Improvements RFC:
 https://github.com/tensorflow/community/pull/179

PiperOrigin-RevId: 283675621
Change-Id: I355d72b476c3f4222dd3e83768a374bf3cc8beb3
---
 tensorflow/core/BUILD                         |   4 +-
 tensorflow/core/platform/BUILD                |  11 ++
 .../core/platform/default/build_refactor.bzl  |  19 ++++
 .../core/platform/default/test_benchmark.cc   |   2 +-
 .../core/platform/default/test_benchmark.h    | 105 ++++++++++++++++++
 tensorflow/core/platform/test_benchmark.h     |  93 +---------------
 6 files changed, 140 insertions(+), 94 deletions(-)
 create mode 100644 tensorflow/core/platform/default/test_benchmark.h

diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index f07251955dc..588420eb1b6 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -573,8 +573,8 @@ cc_library(
     name = "util_reporter",
     srcs = ["util/reporter.cc"],
     hdrs = ["util/reporter.h"],
-    # Not to be used outside this file.
-    visibility = ["//visibility:private"],
+    # This should only be used in tensorflow/core/platform:test_benchmark
+    visibility = ["//tensorflow/core/platform:__subpackages__"],
     deps = [
         ":test_log_proto_impl_cc",
         "//tensorflow/core/platform:env",
diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD
index 534a3b4e7cc..8e6fd49d1ab 100644
--- a/tensorflow/core/platform/BUILD
+++ b/tensorflow/core/platform/BUILD
@@ -70,6 +70,7 @@ tf_instantiate_platform_libraries(names = [
     "strong_hash",
     "subprocess",
     "test",
+    "test_benchmark",
     "tracing",
     "types",
     "unbounded_work_queue",
@@ -635,6 +636,16 @@ cc_library(
     deps = tf_mobile_aware_deps("test_impl"),
 )
 
+cc_library(
+    name = "test_benchmark",
+    testonly = True,
+    hdrs = ["test_benchmark.h"],
+    deps = [
+        ":platform",
+        ":test_benchmark_impl",
+    ],
+)
+
 cc_library(
     name = "thread_annotations",
     hdrs = ["thread_annotations.h"],
diff --git a/tensorflow/core/platform/default/build_refactor.bzl b/tensorflow/core/platform/default/build_refactor.bzl
index a29cce63fd7..4f11699f766 100644
--- a/tensorflow/core/platform/default/build_refactor.bzl
+++ b/tensorflow/core/platform/default/build_refactor.bzl
@@ -322,6 +322,25 @@ TF_DEFAULT_PLATFORM_LIBRARIES = {
         "tags": ["no_oss", "manual"],
         "visibility": ["//visibility:private"],
     },
+    "test_benchmark": {
+        "name": "test_benchmark_impl",
+        "testonly": True,
+        "srcs": [
+            "//tensorflow/core/platform:default/test_benchmark.cc",
+        ],
+        "hdrs": [
+            "//tensorflow/core/platform:default/test_benchmark.h",
+        ],
+        "deps": [
+            "//tensorflow/core/platform",
+            "//tensorflow/core/platform:env",
+            "//tensorflow/core/platform:macros",
+            "//tensorflow/core/platform:types",
+            "//tensorflow/core:util_reporter",
+        ],
+        "tags": ["no_oss", "manual"],
+        "visibility": ["//visibility:private"],
+    },
     "tracing": {
         "name": "tracing_impl",
         "textual_hdrs": [
diff --git a/tensorflow/core/platform/default/test_benchmark.cc b/tensorflow/core/platform/default/test_benchmark.cc
index 77747fc5b79..533c4ac1df1 100644
--- a/tensorflow/core/platform/default/test_benchmark.cc
+++ b/tensorflow/core/platform/default/test_benchmark.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/platform/default/test_benchmark.h"
 
 #include 
 #include 
diff --git a/tensorflow/core/platform/default/test_benchmark.h b/tensorflow/core/platform/default/test_benchmark.h
new file mode 100644
index 00000000000..203a8a045ff
--- /dev/null
+++ b/tensorflow/core/platform/default/test_benchmark.h
@@ -0,0 +1,105 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Simple benchmarking facility.
+#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_TEST_BENCHMARK_H_
+#define TENSORFLOW_CORE_PLATFORM_DEFAULT_TEST_BENCHMARK_H_
+
+#include 
+#include 
+
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/platform.h"
+#include "tensorflow/core/platform/types.h"
+
+#define BENCHMARK(n)                                            \
+  static ::tensorflow::testing::Benchmark* TF_BENCHMARK_CONCAT( \
+      __benchmark_, n, __LINE__) TF_ATTRIBUTE_UNUSED =          \
+      (new ::tensorflow::testing::Benchmark(#n, (n)))
+#define TF_BENCHMARK_CONCAT(a, b, c) TF_BENCHMARK_CONCAT2(a, b, c)
+#define TF_BENCHMARK_CONCAT2(a, b, c) a##b##c
+
+namespace tensorflow {
+namespace testing {
+
+// The DoNotOptimize(...) function can be used to prevent a value or
+// expression from being optimized away by the compiler. This function is
+// intended to add little to no overhead.
+// See: http://stackoverflow.com/questions/28287064
+//
+// The specific guarantees of DoNotOptimize(x) are:
+//  1) x, and any data it transitively points to, will exist (in a register or
+//     in memory) at the current point in the program.
+//  2) The optimizer will assume that DoNotOptimize(x) could mutate x or
+//     anything it transitively points to (although it actually doesn't).
+//
+// To see this in action:
+//
+//   void BM_multiply(benchmark::State& state) {
+//     int a = 2;
+//     int b = 4;
+//     for (auto _ : state) {
+//       testing::DoNotOptimize(a);
+//       testing::DoNotOptimize(b);
+//       int c = a * b;
+//       testing::DoNotOptimize(c);
+//     }
+//   }
+//   BENCHMARK(BM_multiply);
+//
+// Guarantee (2) applied to 'a' and 'b' prevents the compiler lifting the
+// multiplication outside of the loop. Guarantee (1) applied to 'c' prevents the
+// compiler from optimizing away 'c' as dead code.
+template 
+void DoNotOptimize(const T& var) {
+  asm volatile("" : "+m"(const_cast(var)));
+}
+
+class Benchmark {
+ public:
+  Benchmark(const char* name, void (*fn)(int));
+  Benchmark(const char* name, void (*fn)(int, int));
+  Benchmark(const char* name, void (*fn)(int, int, int));
+
+  Benchmark* Arg(int x);
+  Benchmark* ArgPair(int x, int y);
+  Benchmark* Range(int lo, int hi);
+  Benchmark* RangePair(int lo1, int hi1, int lo2, int hi2);
+  static void Run(const char* pattern);
+
+ private:
+  string name_;
+  int num_args_;
+  std::vector > args_;
+  void (*fn0_)(int) = nullptr;
+  void (*fn1_)(int, int) = nullptr;
+  void (*fn2_)(int, int, int) = nullptr;
+
+  void Register();
+  void Run(int arg1, int arg2, int* run_count, double* run_seconds);
+};
+
+void RunBenchmarks();
+void SetLabel(const std::string& label);
+void BytesProcessed(int64);
+void ItemsProcessed(int64);
+void StartTiming();
+void StopTiming();
+void UseRealTime();
+
+}  // namespace testing
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PLATFORM_DEFAULT_TEST_BENCHMARK_H_
diff --git a/tensorflow/core/platform/test_benchmark.h b/tensorflow/core/platform/test_benchmark.h
index 61fcd0d372c..aff1edb2d51 100644
--- a/tensorflow/core/platform/test_benchmark.h
+++ b/tensorflow/core/platform/test_benchmark.h
@@ -17,102 +17,13 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_PLATFORM_TEST_BENCHMARK_H_
 #define TENSORFLOW_CORE_PLATFORM_TEST_BENCHMARK_H_
 
-#include 
-#include 
-#include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/platform.h"
-#include "tensorflow/core/platform/types.h"
 
 #if defined(PLATFORM_GOOGLE)
-#include "tensorflow/core/platform/google/build_config/benchmark.h"
-
+#include "tensorflow/core/platform/google/test_benchmark.h"
 #else
-#define BENCHMARK(n)                                            \
-  static ::tensorflow::testing::Benchmark* TF_BENCHMARK_CONCAT( \
-      __benchmark_, n, __LINE__) TF_ATTRIBUTE_UNUSED =          \
-      (new ::tensorflow::testing::Benchmark(#n, (n)))
-#define TF_BENCHMARK_CONCAT(a, b, c) TF_BENCHMARK_CONCAT2(a, b, c)
-#define TF_BENCHMARK_CONCAT2(a, b, c) a##b##c
-
+#include "tensorflow/core/platform/default/test_benchmark.h"
 #endif  // PLATFORM_GOOGLE
 
-namespace tensorflow {
-namespace testing {
-
-#if defined(PLATFORM_GOOGLE)
-
-using ::testing::Benchmark;
-using ::testing::DoNotOptimize;
-
-#else
-
-// The DoNotOptimize(...) function can be used to prevent a value or
-// expression from being optimized away by the compiler. This function is
-// intended to add little to no overhead.
-// See: http://stackoverflow.com/questions/28287064
-//
-// The specific guarantees of DoNotOptimize(x) are:
-//  1) x, and any data it transitively points to, will exist (in a register or
-//     in memory) at the current point in the program.
-//  2) The optimizer will assume that DoNotOptimize(x) could mutate x or
-//     anything it transitively points to (although it actually doesn't).
-//
-// To see this in action:
-//
-//   void BM_multiply(benchmark::State& state) {
-//     int a = 2;
-//     int b = 4;
-//     for (auto _ : state) {
-//       testing::DoNotOptimize(a);
-//       testing::DoNotOptimize(b);
-//       int c = a * b;
-//       testing::DoNotOptimize(c);
-//     }
-//   }
-//   BENCHMARK(BM_multiply);
-//
-// Guarantee (2) applied to 'a' and 'b' prevents the compiler lifting the
-// multiplication outside of the loop. Guarantee (1) applied to 'c' prevents the
-// compiler from optimizing away 'c' as dead code.
-template 
-void DoNotOptimize(const T& var) {
-  asm volatile("" : "+m"(const_cast(var)));
-}
-
-class Benchmark {
- public:
-  Benchmark(const char* name, void (*fn)(int));
-  Benchmark(const char* name, void (*fn)(int, int));
-  Benchmark(const char* name, void (*fn)(int, int, int));
-
-  Benchmark* Arg(int x);
-  Benchmark* ArgPair(int x, int y);
-  Benchmark* Range(int lo, int hi);
-  Benchmark* RangePair(int lo1, int hi1, int lo2, int hi2);
-  static void Run(const char* pattern);
-
- private:
-  string name_;
-  int num_args_;
-  std::vector > args_;
-  void (*fn0_)(int) = nullptr;
-  void (*fn1_)(int, int) = nullptr;
-  void (*fn2_)(int, int, int) = nullptr;
-
-  void Register();
-  void Run(int arg1, int arg2, int* run_count, double* run_seconds);
-};
-#endif
-
-void RunBenchmarks();
-void SetLabel(const std::string& label);
-void BytesProcessed(int64);
-void ItemsProcessed(int64);
-void StartTiming();
-void StopTiming();
-void UseRealTime();
-
-}  // namespace testing
-}  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_PLATFORM_TEST_BENCHMARK_H_

From 649d1b34da7a90f0dbc024911751a34bef591290 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Tue, 3 Dec 2019 20:21:18 -0800
Subject: [PATCH 268/279] When exporting SavedModel, force functional and
 sequential models to save config even if error occurs when serializing
 layers.

This ensures that the network structure is saved even if a custom layer doesn't define its config.

PiperOrigin-RevId: 283676690
Change-Id: I15e571fcf6a0ca21fd38a2cca16ecf1e690fa5a1
---
 tensorflow/python/keras/BUILD                 | 12 -----
 tensorflow/python/keras/engine/sequential.py  | 18 +++----
 .../saving/saved_model/layer_serialization.py | 23 ++++-----
 .../saved_model/model_serialization_test.py   | 48 -------------------
 .../python/keras/utils/generic_utils.py       | 33 +++----------
 5 files changed, 25 insertions(+), 109 deletions(-)
 delete mode 100644 tensorflow/python/keras/saving/saved_model/model_serialization_test.py

diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 069586d3d59..d6fb60fd724 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -1857,18 +1857,6 @@ tf_py_test(
     ],
 )
 
-tf_py_test(
-    name = "model_serialization_test",
-    size = "medium",
-    srcs = ["saving/saved_model/model_serialization_test.py"],
-    additional_deps = [
-        ":keras",
-        "@absl_py//absl/testing:parameterized",
-        "//tensorflow/python/distribute:mirrored_strategy",
-        "//tensorflow/python:client_testlib",
-    ],
-)
-
 tf_py_test(
     name = "saving_utils_test",
     size = "medium",
diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py
index 522aed6aaa4..369cd31d656 100644
--- a/tensorflow/python/keras/engine/sequential.py
+++ b/tensorflow/python/keras/engine/sequential.py
@@ -345,17 +345,13 @@ class Sequential(training.Model):
     layer_configs = []
     for layer in self.layers:
       layer_configs.append(generic_utils.serialize_keras_object(layer))
-
-    if layer_configs and layer_configs[0]['config'] is not None:
-      # layer_configs[0]['config'] may be None only when saving SavedModel.
-
-      # Check to see whether the first non-input layer has the shape information
-      # to reconstruct `Sequential` as a graph network. If not, add it.
-      if (self._is_graph_network and
-          'batch_input_shape' not in layer_configs[0]['config'] and
-          isinstance(self._layers[0], input_layer.InputLayer)):
-        batch_input_shape = self._layers[0]._batch_input_shape
-        layer_configs[0]['config']['batch_input_shape'] = batch_input_shape
+    # When constructed using an `InputLayer` the first non-input layer may not
+    # have the shape information to reconstruct `Sequential` as a graph network.
+    if (self._is_graph_network and layer_configs and
+        'batch_input_shape' not in layer_configs[0]['config'] and
+        isinstance(self._layers[0], input_layer.InputLayer)):
+      batch_input_shape = self._layers[0]._batch_input_shape
+      layer_configs[0]['config']['batch_input_shape'] = batch_input_shape
 
     config = {
         'name': self.name,
diff --git a/tensorflow/python/keras/saving/saved_model/layer_serialization.py b/tensorflow/python/keras/saving/saved_model/layer_serialization.py
index ab1edaab585..054a01e1db0 100644
--- a/tensorflow/python/keras/saving/saved_model/layer_serialization.py
+++ b/tensorflow/python/keras/saving/saved_model/layer_serialization.py
@@ -23,7 +23,7 @@ from tensorflow.python.keras.saving.saved_model import base_serialization
 from tensorflow.python.keras.saving.saved_model import constants
 from tensorflow.python.keras.saving.saved_model import save_impl
 from tensorflow.python.keras.saving.saved_model import serialized_attributes
-from tensorflow.python.keras.utils import generic_utils
+from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
 from tensorflow.python.util import nest
 
 
@@ -51,22 +51,23 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):
         expects_training_arg=self.obj._expects_training_arg,  # pylint: disable=protected-access
         dtype=policy.serialize(self.obj._dtype_policy),  # pylint: disable=protected-access
         batch_input_shape=getattr(self.obj, '_batch_input_shape', None))
-
-    with generic_utils.skip_failed_serialization():
-      # Store the config dictionary, which may be used when reviving the object.
-      # When loading, the program will attempt to revive the object from config,
-      # and if that fails, the object will be revived from the SavedModel.
-      config = generic_utils.serialize_keras_object(self.obj)['config']
-      if config is not None:
-        metadata['config'] = config
+    try:
+      # Store the config dictionary, which is only used by the revived object
+      # to return the original config when revived_obj.get_config() is called.
+      # It is not important for recreating the revived object.
+      metadata['config'] = self.obj.get_config()
+    except NotImplementedError:
+      # in the case of a subclassed model, the get_config() method will throw
+      # a NotImplementedError.
+      pass
     if self.obj.input_spec is not None:
       # Layer's input_spec has already been type-checked in the property setter.
       metadata['input_spec'] = nest.map_structure(
-          lambda x: generic_utils.serialize_keras_object(x) if x else None,
+          lambda x: None if x is None else serialize_keras_object(x),
           self.obj.input_spec)
     if (self.obj.activity_regularizer is not None and
         hasattr(self.obj.activity_regularizer, 'get_config')):
-      metadata['activity_regularizer'] = generic_utils.serialize_keras_object(
+      metadata['activity_regularizer'] = serialize_keras_object(
           self.obj.activity_regularizer)
     return metadata
 
diff --git a/tensorflow/python/keras/saving/saved_model/model_serialization_test.py b/tensorflow/python/keras/saving/saved_model/model_serialization_test.py
deleted file mode 100644
index 125ab2fd958..00000000000
--- a/tensorflow/python/keras/saving/saved_model/model_serialization_test.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-# pylint: disable=protected-access
-"""Unit tests for serializing Keras models."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python import keras
-from tensorflow.python.keras import keras_parameterized
-from tensorflow.python.keras import testing_utils
-from tensorflow.python.platform import test
-
-
-class CustomLayer(keras.layers.Layer):
-
-  def __init__(self, unused_a):
-    super(CustomLayer, self).__init__()
-
-
-class ModelSerializationTest(keras_parameterized.TestCase):
-
-  @keras_parameterized.run_with_all_model_types(exclude_models=['subclass'])
-  def test_model_config_always_saved(self):
-    layer = CustomLayer(None)
-    with self.assertRaisesRegexp(NotImplementedError,
-                                 'must override `get_config`.'):
-      layer.get_config()
-    model = testing_utils.get_model_from_layers([layer], input_shape=(3,))
-    properties = model._trackable_saved_model_saver.python_properties
-    self.assertIsNotNone(properties['config'])
-
-
-if __name__ == '__main__':
-  test.main()
diff --git a/tensorflow/python/keras/utils/generic_utils.py b/tensorflow/python/keras/utils/generic_utils.py
index 8b899dc0c74..8ff27a38d77 100644
--- a/tensorflow/python/keras/utils/generic_utils.py
+++ b/tensorflow/python/keras/utils/generic_utils.py
@@ -30,7 +30,6 @@ import numpy as np
 import six
 
 from tensorflow.python.util import nest
-from tensorflow.python.util import tf_contextlib
 from tensorflow.python.util import tf_decorator
 from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.tf_export import keras_export
@@ -38,11 +37,6 @@ from tensorflow.python.util.tf_export import keras_export
 _GLOBAL_CUSTOM_OBJECTS = {}
 _GLOBAL_CUSTOM_NAMES = {}
 
-# Flag that determines whether to skip the NotImplementedError when calling
-# get_config in custom models and layers. This is only enabled when saving to
-# SavedModel, when the config isn't required.
-_SKIP_FAILED_SERIALIZATION = False
-
 
 @keras_export('keras.utils.CustomObjectScope')
 class CustomObjectScope(object):
@@ -193,17 +187,6 @@ def _get_name_or_custom_name(obj):
     return obj.__name__
 
 
-@tf_contextlib.contextmanager
-def skip_failed_serialization():
-  global _SKIP_FAILED_SERIALIZATION
-  prev = _SKIP_FAILED_SERIALIZATION
-  try:
-    _SKIP_FAILED_SERIALIZATION = True
-    yield
-  finally:
-    _SKIP_FAILED_SERIALIZATION = prev
-
-
 @keras_export('keras.utils.serialize_keras_object')
 def serialize_keras_object(instance):
   """Serialize Keras object into JSON."""
@@ -212,13 +195,7 @@ def serialize_keras_object(instance):
     return None
 
   if hasattr(instance, 'get_config'):
-    name = _get_name_or_custom_name(instance.__class__)
-    try:
-      config = instance.get_config()
-    except NotImplementedError as e:
-      if _SKIP_FAILED_SERIALIZATION:
-        return serialize_keras_class_and_config(name, None)
-      raise e
+    config = instance.get_config()
     serialization_config = {}
     for key, item in config.items():
       if isinstance(item, six.string_types):
@@ -234,13 +211,15 @@ def serialize_keras_object(instance):
         serialization_config[key] = serialized_item
       except ValueError:
         serialization_config[key] = item
+
+    name = _get_name_or_custom_name(instance.__class__)
     return serialize_keras_class_and_config(name, serialization_config)
   if hasattr(instance, '__name__'):
     return _get_name_or_custom_name(instance)
   raise ValueError('Cannot serialize', instance)
 
 
-def get_custom_objects_by_name(item, custom_objects=None):
+def _get_custom_objects_by_name(item, custom_objects=None):
   """Returns the item if it is in either local or global custom objects."""
   if item in _GLOBAL_CUSTOM_OBJECTS:
     return _GLOBAL_CUSTOM_OBJECTS[item]
@@ -281,7 +260,7 @@ def class_and_config_for_serialized_keras_object(
           printable_module_name='config_item')
     elif (isinstance(item, six.string_types) and
           tf_inspect.isfunction(
-              get_custom_objects_by_name(item, custom_objects))):
+              _get_custom_objects_by_name(item, custom_objects))):
       # Handle custom functions here. When saving functions, we only save the
       # function's name as a string. If we find a matching string in the custom
       # objects during deserialization, we convert the string back to the
@@ -290,7 +269,7 @@ def class_and_config_for_serialized_keras_object(
       # conflict with a custom function name, but this should be a rare case.
       # This issue does not occur if a string field has a naming conflict with
       # a custom object, since the config of an object will always be a dict.
-      deserialized_objects[key] = get_custom_objects_by_name(
+      deserialized_objects[key] = _get_custom_objects_by_name(
           item, custom_objects)
   for key, item in deserialized_objects.items():
     cls_config[key] = deserialized_objects[key]

From 31e7db70623f91591d3d936ade23735e9abf543a Mon Sep 17 00:00:00 2001
From: Tong Shen 
Date: Tue, 3 Dec 2019 21:03:44 -0800
Subject: [PATCH 269/279] Add use_sharding_op parameter to xla_sharding.tile().

PiperOrigin-RevId: 283680850
Change-Id: Ida5df90d1baa379fc714b233b639462cf0b09e39
---
 .../experimental/xla_sharding/xla_sharding.py    | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py
index 52d6444e5bb..64c85b37504 100644
--- a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py
+++ b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py
@@ -188,7 +188,21 @@ def assign_device(tensor, device, assign_tuple_sharding=False):
   return tensor
 
 
-def tile(tensor, tile_assignment, assign_tuple_sharding=False):
+def tile(tensor,
+         tile_assignment,
+         assign_tuple_sharding=False,
+         use_sharding_op=False):
+  """Returns a tensor that has tiled sharding.
+
+  Args:
+    tensor: A tf.Tensor to shard.
+    tile_assignment: An np.ndarray describing the topology of the tiling and
+      which device will compute which part of the topology.
+    assign_tuple_sharding: If the sharding type should be a tuple.
+    use_sharding_op: If true, adds a sharding op to set the sharding.
+  """
+  if use_sharding_op:
+    tensor = tf2xla.sharding(tensor)
   Sharding.tile(tile_assignment).apply_to_tensor(
       tensor,
       assign_tuple_sharding=assign_tuple_sharding

From a9458063f4505e86b64ca235e3d7556697609791 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Tue, 3 Dec 2019 21:11:46 -0800
Subject: [PATCH 270/279] Make DistributedIterator iterable.

All Python iterators should themselves be iterables (just returning `self`). This is already the case for "OwnedIterator" (return type of iter(dataset) for non-distributed datasets), but not for DistributedIterator.

PiperOrigin-RevId: 283681632
Change-Id: I98aa4d950d6ecc0a3c9c5cefa72c05847cc9a3c2
---
 tensorflow/python/distribute/input_lib.py     |  3 +++
 .../python/distribute/input_lib_test.py       | 21 +++++++++++++++++++
 2 files changed, 24 insertions(+)

diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py
index f1f9a0e872d..80d03ed438a 100644
--- a/tensorflow/python/distribute/input_lib.py
+++ b/tensorflow/python/distribute/input_lib.py
@@ -278,6 +278,9 @@ class DistributedIterator(object):
     except errors.OutOfRangeError:
       raise StopIteration
 
+  def __iter__(self):
+    return self
+
   def get_next(self, name=None):
     """Returns the next input from the iterator for all replicas."""
     if not self._enable_get_next_as_optional:
diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py
index 96363053219..1cca10a77a2 100644
--- a/tensorflow/python/distribute/input_lib_test.py
+++ b/tensorflow/python/distribute/input_lib_test.py
@@ -417,6 +417,27 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
         expected_values,
         distribution)
 
+  @combinations.generate(
+      combinations.combine(
+          mode=["eager"],
+          distribution=[
+              strategy_combinations.one_device_strategy,
+              strategy_combinations.mirrored_strategy_with_one_cpu
+          ]))
+  def testIterableIterator(self, distribution):
+    worker_device_pairs = [("", ["/device:CPU:0"])]
+    devices = nest.flatten([ds for _, ds in worker_device_pairs])
+    device_map = values.ReplicaDeviceMap(devices)
+    input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)
+
+    dataset = dataset_ops.DatasetV2.range(10)
+    dist_dataset = input_lib.get_distributed_dataset(dataset, input_workers,
+                                                     distribution)
+
+    iterator = iter(dist_dataset)
+    for i, element in enumerate(iterator):
+      self.assertEqual(i, element.numpy())
+
   @combinations.generate(
       combinations.combine(
           mode=["graph", "eager"],

From 28467d1e2e4f03cbc306715f7bfea0036364d319 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Tue, 3 Dec 2019 21:19:10 -0800
Subject: [PATCH 271/279] Internal change

PiperOrigin-RevId: 283682281
Change-Id: Ie9ffc9fd5adbb91dfca07df6500a6a52f88f046e
---
 third_party/mlir/bindings/python/BUILD | 9 +--------
 1 file changed, 1 insertion(+), 8 deletions(-)

diff --git a/third_party/mlir/bindings/python/BUILD b/third_party/mlir/bindings/python/BUILD
index f9941ca1336..64ade7f43e2 100644
--- a/third_party/mlir/bindings/python/BUILD
+++ b/third_party/mlir/bindings/python/BUILD
@@ -7,14 +7,7 @@ licenses(["notice"])  # Apache 2.0
 exports_files(["BUILD"])
 
 package(
-    default_visibility = [":friends"],
-)
-
-package_group(
-    name = "friends",
-    packages = [
-        "@local_config_mlir//bindings/...",
-    ],
+    default_visibility = ["@local_config_mlir//:friends"],
 )
 
 #

From 9b423a8fc6585b89a511d4e76cfc0af31700ea9e Mon Sep 17 00:00:00 2001
From: Smit Hinsu 
Date: Tue, 3 Dec 2019 21:25:02 -0800
Subject: [PATCH 272/279] Avoid variable name conflict in MLIR tutorial code
 snippet

PiperOrigin-RevId: 283682865
Change-Id: Idfe434a83b91ac48117b63524cf9fc0379daa3a7
---
 third_party/mlir/g3doc/Tutorials/Toy/Ch-2.md | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/third_party/mlir/g3doc/Tutorials/Toy/Ch-2.md b/third_party/mlir/g3doc/Tutorials/Toy/Ch-2.md
index c23597244dd..d797624ed72 100755
--- a/third_party/mlir/g3doc/Tutorials/Toy/Ch-2.md
+++ b/third_party/mlir/g3doc/Tutorials/Toy/Ch-2.md
@@ -270,16 +270,17 @@ types, etc). We can always get an instance of our toy operation by using LLVM's
 casting infrastructure:
 
 ```c++
-void processConstantOp(mlir::Operation *op) {
-  ConstantOp op = llvm::dyn_cast(op);
+void processConstantOp(mlir::Operation *operation) {
+  ConstantOp op = llvm::dyn_cast(operation);
 
   // This operation is not an instance of `ConstantOp`.
   if (!op)
     return;
 
   // Get the internal operation instance back.
-  mlir::Operation *internalOp = op.getOperation();
-  assert(internalOp == op && "these operation instances are the same");
+  mlir::Operation *internalOperation = op.getOperation();
+  assert(internalOperation == operation &&
+         "these operation instances are the same");
 }
 ```
 

From 6faa3bcd4c0acd5c617aab34e21c590b9fb5eb61 Mon Sep 17 00:00:00 2001
From: Xunkai Zhang 
Date: Tue, 3 Dec 2019 21:28:52 -0800
Subject: [PATCH 273/279] Do not allow reshape (even with equal flat size)
 fixed size TensorBuffers.

PiperOrigin-RevId: 283683221
Change-Id: Ia9547c9f8631fcae1175ccdb403e1e84992013eb
---
 .../lite/support/tensorbuffer/TensorBuffer.java          | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java
index 286f91f5037..ea6a085a3bc 100644
--- a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java
+++ b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java
@@ -61,10 +61,7 @@ public abstract class TensorBuffer {
    * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8);
    * 
* - *

The size of a fixed-size TensorBuffer cannot be changed once it is created. However, loading - * arraies or data buffers of the same buffer size but different shapes is allowed. - * - *

TODO(b/139782181): Shall we make it fixed-size or fixed-shape? + *

The size of a fixed-size TensorBuffer cannot be changed once it is created. * * @param shape The shape of the {@link TensorBuffer} to be created. * @param dataType The dataType of the {@link TensorBuffer} to be created. @@ -87,7 +84,7 @@ public abstract class TensorBuffer { * Creates an empty dynamic {@link TensorBuffer} with specified {@link DataType}. The shape of the * created {@link TensorBuffer} is {0}. * - *

Dynamic TensorBuffers will reallocate memory when Loading arraies or data buffers of + *

Dynamic TensorBuffers will reallocate memory when loading arrays or data buffers of * different buffer sizes. * * @param dataType The dataType of the {@link TensorBuffer} to be created. @@ -326,7 +323,7 @@ public abstract class TensorBuffer { allocateMemory(shape); } else { // Make sure the new shape fits the buffer size when TensorBuffer has fixed size. - SupportPreconditions.checkArgument(flatSize == computeFlatSize(shape)); + SupportPreconditions.checkArgument(Arrays.equals(shape, this.shape)); this.shape = shape.clone(); } } From d211a7eeaa65f8e2b24396be79d1d54bcb5ba806 Mon Sep 17 00:00:00 2001 From: Gaurav Jain Date: Tue, 3 Dec 2019 21:48:26 -0800 Subject: [PATCH 274/279] Add simple 0/1 arithmetic optimizations Adding or subtracting a scalar 0 or multiply/divide by a scalar 1 should just return the first argument. PiperOrigin-RevId: 283685484 Change-Id: I1cf4328de2fc16d12d9d73817094229ad9ae13e1 --- tensorflow/python/ops/math_ops.py | 31 +++++++++++++---- tensorflow/python/ops/math_ops_test.py | 48 ++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 078219e2f23..8473ea9aa96 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -341,12 +341,19 @@ def divide(x, y, name=None): # override names. Use a dummy class to track the runtime division behavior return DivideDelegateWithName(x, name) / y else: + # Do an is comparison here since this is cheaper than isinstance or __eq__ + if y is 1: # pylint: disable=literal-comparison + return x return x / y @tf_export("math.multiply", "multiply") @dispatch.add_dispatch_support -def multiply(x, y, name=None): +def multiply(x, y, name=None): # pylint: disable=missing-docstring + # Do an is comparison here since this is cheaper than isinstance or __eq__ + if y is 1: # pylint: disable=literal-comparison + return x + return gen_math_ops.mul(x, y, name) @@ -358,16 +365,28 @@ multiply.__doc__ = gen_math_ops.mul.__doc__.replace("Multiply", "tf.multiply") "2016-12-30", "`tf.mul(x, y)` is deprecated, please use `tf.multiply(x, y)` or `x * y`") def _mul(x, y, name=None): - return gen_math_ops.mul(x, y, name) + return multiply(x, y, name) _mul.__doc__ = ( gen_math_ops.mul.__doc__ + ("" if _mul.__doc__ is None else _mul.__doc__)) +def add_v2(x, y, name=None): + # Do an is comparison here since this is cheaper than isinstance or __eq__ + if y is 0: # pylint: disable=literal-comparison + return x + + return gen_math_ops.add_v2(x, y, name=name) + + @tf_export("math.subtract", "subtract") @dispatch.add_dispatch_support def subtract(x, y, name=None): + # Do an is comparison here since this is cheaper than isinstance or __eq__ + if y is 0: # pylint: disable=literal-comparison + return x + return gen_math_ops.sub(x, y, name) @@ -379,7 +398,7 @@ subtract.__doc__ = gen_math_ops.sub.__doc__.replace("`Sub`", "`tf.subtract`") "2016-12-30", "`tf.sub(x, y)` is deprecated, please use `tf.subtract(x, y)` or `x - y`") def _sub(x, y, name=None): - return gen_math_ops.sub(x, y, name) + return subtract(x, y, name) _sub.__doc__ = ( @@ -1207,14 +1226,14 @@ def _add_dispatch(x, y, name=None): if x.dtype == dtypes.string: return gen_math_ops.add(x, y, name=name) else: - return gen_math_ops.add_v2(x, y, name=name) + return add_v2(x, y, name=name) def _mul_dispatch(x, y, name=None): """Dispatches cwise mul for "Dense*Dense" and "Dense*Sparse".""" is_tensor_y = isinstance(y, ops.Tensor) if is_tensor_y: - return gen_math_ops.mul(x, y, name=name) + return gen_math_ops.mul(x, y, name) else: assert isinstance(y, sparse_tensor.SparseTensor) # Case: Dense * Sparse. new_vals = gen_sparse_ops.sparse_dense_cwise_mul(y.indices, y.values, @@ -1233,7 +1252,7 @@ _OverrideBinaryOperatorHelper(gen_sparse_ops.sparse_dense_cwise_mul, "mul", sparse_tensor.SparseTensor) _OverrideBinaryOperatorHelper(_add_dispatch, "add") -_OverrideBinaryOperatorHelper(gen_math_ops.sub, "sub") +_OverrideBinaryOperatorHelper(subtract, "sub") _OverrideBinaryOperatorHelper(_mul_dispatch, "mul") _OverrideBinaryOperatorHelper(_div_python2, "div") _OverrideBinaryOperatorHelper(_truediv_python3, "truediv") diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py index f49ba3dd2a3..834726749f8 100644 --- a/tensorflow/python/ops/math_ops_test.py +++ b/tensorflow/python/ops/math_ops_test.py @@ -689,5 +689,53 @@ class RangeTest(test_util.TensorFlowTestCase): self.assertAllEqual(values, self.evaluate(tensor)) +@test_util.run_all_in_graph_and_eager_modes +class ScalarOptimizationTest(test_util.TensorFlowTestCase): + + def testAddZero(self): + x = constant_op.constant(1) + y = math_ops.add_v2(x, 0) + self.assertAllEqual(x, y) + self.assertIs(x, y) + + # Optimization not applied + y = math_ops.add_v2(x, constant_op.constant(0)) + self.assertAllEqual(x, y) + self.assertIsNot(x, y) + + def testSubtractZero(self): + x = constant_op.constant(1) + y = math_ops.subtract(x, 0) + self.assertAllEqual(x, y) + self.assertIs(x, y) + + # Optimization not applied + y = math_ops.subtract(x, constant_op.constant(0)) + self.assertAllEqual(x, y) + self.assertIsNot(x, y) + + def testMultiplyOne(self): + x = constant_op.constant(1) + y = math_ops.multiply(x, 1) + self.assertAllEqual(x, y) + self.assertIs(x, y) + + # Optimization not applied + y = math_ops.multiply(x, constant_op.constant(1)) + self.assertAllEqual(x, y) + self.assertIsNot(x, y) + + def testDivideOne(self): + x = constant_op.constant(1) + y = math_ops.divide(x, 1) + self.assertAllEqual(x, y) + self.assertIs(x, y) + + # Optimization not applied + y = math_ops.divide(x, constant_op.constant(1)) + self.assertAllEqual(x, y) + self.assertIsNot(x, y) + + if __name__ == "__main__": googletest.main() From 4831a9f1de914535d93ee98664b64738cf46b624 Mon Sep 17 00:00:00 2001 From: Yanhui Liang Date: Tue, 3 Dec 2019 21:53:27 -0800 Subject: [PATCH 275/279] Remove duplicates and fix typo in saved model format. PiperOrigin-RevId: 283686133 Change-Id: I8005c91c029720349533eaf60562f095de0d4bbe --- tensorflow/python/keras/testing_utils.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py index 146807028cb..aa4059cb50e 100644 --- a/tensorflow/python/keras/testing_utils.py +++ b/tensorflow/python/keras/testing_utils.py @@ -377,20 +377,10 @@ def saved_model_format_scope(value): _thread_local_data.saved_model_format = previous_value -def get_saved_model_format(): - """Gets the saved model format that should be tested.""" - if _thread_local_data.saved_model_format is None: - raise ValueError( - 'Cannot call `get_saved_model_format()` outside of a ' - '`saved_model_format_scope()` or `run_with_all_saved_model_formats` ' - 'decorator.') - return _thread_local_data.saved_model_format - - def get_save_format(): if _thread_local_data.saved_model_format is None: raise ValueError( - 'Cannot call `get_saved_model_format()` outside of a ' + 'Cannot call `get_save_format()` outside of a ' '`saved_model_format_scope()` or `run_with_all_saved_model_formats` ' 'decorator.') return _thread_local_data.saved_model_format From f505dc5a4967b17600bf9e12a7d5b32ea8c5a47f Mon Sep 17 00:00:00 2001 From: Khanh LeViet Date: Tue, 3 Dec 2019 22:00:50 -0800 Subject: [PATCH 276/279] Updated broken link in TF Lite iOS getting started doc. PiperOrigin-RevId: 283687031 Change-Id: I4033496fa1519b8a1adbe04c704f440067da7e22 --- tensorflow/lite/g3doc/guide/ios.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/lite/g3doc/guide/ios.md b/tensorflow/lite/g3doc/guide/ios.md index fc997bccf9d..0c7e5dc9c90 100644 --- a/tensorflow/lite/g3doc/guide/ios.md +++ b/tensorflow/lite/g3doc/guide/ios.md @@ -7,7 +7,7 @@ example: image classification example For an explanation of the source code, you should also read -[TensorFlow Lite iOS image classification](https://www.tensorflow.org/lite/models/image_classification/ios). +[TensorFlow Lite iOS image classification](https://www.tensorflow.org/code/py/tensorflow_examples/lite/examples/image_classification/ios/EXPLORE_THE_CODE.md). This example app uses [image classification](https://www.tensorflow.org/lite/models/image_classification/overview) From 92cb2db3dca8100400583d80d719099d421aa4b0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 3 Dec 2019 22:37:31 -0800 Subject: [PATCH 277/279] Add simple 0/1 arithmetic optimizations Adding or subtracting a scalar 0 or multiply/divide by a scalar 1 should just return the first argument. PiperOrigin-RevId: 283691465 Change-Id: I18e7ca94811303df2265ebcaa874d4a1bd2ad7f7 --- tensorflow/python/ops/math_ops.py | 31 ++++------------- tensorflow/python/ops/math_ops_test.py | 48 -------------------------- 2 files changed, 6 insertions(+), 73 deletions(-) diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 8473ea9aa96..078219e2f23 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -341,19 +341,12 @@ def divide(x, y, name=None): # override names. Use a dummy class to track the runtime division behavior return DivideDelegateWithName(x, name) / y else: - # Do an is comparison here since this is cheaper than isinstance or __eq__ - if y is 1: # pylint: disable=literal-comparison - return x return x / y @tf_export("math.multiply", "multiply") @dispatch.add_dispatch_support -def multiply(x, y, name=None): # pylint: disable=missing-docstring - # Do an is comparison here since this is cheaper than isinstance or __eq__ - if y is 1: # pylint: disable=literal-comparison - return x - +def multiply(x, y, name=None): return gen_math_ops.mul(x, y, name) @@ -365,28 +358,16 @@ multiply.__doc__ = gen_math_ops.mul.__doc__.replace("Multiply", "tf.multiply") "2016-12-30", "`tf.mul(x, y)` is deprecated, please use `tf.multiply(x, y)` or `x * y`") def _mul(x, y, name=None): - return multiply(x, y, name) + return gen_math_ops.mul(x, y, name) _mul.__doc__ = ( gen_math_ops.mul.__doc__ + ("" if _mul.__doc__ is None else _mul.__doc__)) -def add_v2(x, y, name=None): - # Do an is comparison here since this is cheaper than isinstance or __eq__ - if y is 0: # pylint: disable=literal-comparison - return x - - return gen_math_ops.add_v2(x, y, name=name) - - @tf_export("math.subtract", "subtract") @dispatch.add_dispatch_support def subtract(x, y, name=None): - # Do an is comparison here since this is cheaper than isinstance or __eq__ - if y is 0: # pylint: disable=literal-comparison - return x - return gen_math_ops.sub(x, y, name) @@ -398,7 +379,7 @@ subtract.__doc__ = gen_math_ops.sub.__doc__.replace("`Sub`", "`tf.subtract`") "2016-12-30", "`tf.sub(x, y)` is deprecated, please use `tf.subtract(x, y)` or `x - y`") def _sub(x, y, name=None): - return subtract(x, y, name) + return gen_math_ops.sub(x, y, name) _sub.__doc__ = ( @@ -1226,14 +1207,14 @@ def _add_dispatch(x, y, name=None): if x.dtype == dtypes.string: return gen_math_ops.add(x, y, name=name) else: - return add_v2(x, y, name=name) + return gen_math_ops.add_v2(x, y, name=name) def _mul_dispatch(x, y, name=None): """Dispatches cwise mul for "Dense*Dense" and "Dense*Sparse".""" is_tensor_y = isinstance(y, ops.Tensor) if is_tensor_y: - return gen_math_ops.mul(x, y, name) + return gen_math_ops.mul(x, y, name=name) else: assert isinstance(y, sparse_tensor.SparseTensor) # Case: Dense * Sparse. new_vals = gen_sparse_ops.sparse_dense_cwise_mul(y.indices, y.values, @@ -1252,7 +1233,7 @@ _OverrideBinaryOperatorHelper(gen_sparse_ops.sparse_dense_cwise_mul, "mul", sparse_tensor.SparseTensor) _OverrideBinaryOperatorHelper(_add_dispatch, "add") -_OverrideBinaryOperatorHelper(subtract, "sub") +_OverrideBinaryOperatorHelper(gen_math_ops.sub, "sub") _OverrideBinaryOperatorHelper(_mul_dispatch, "mul") _OverrideBinaryOperatorHelper(_div_python2, "div") _OverrideBinaryOperatorHelper(_truediv_python3, "truediv") diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py index 834726749f8..f49ba3dd2a3 100644 --- a/tensorflow/python/ops/math_ops_test.py +++ b/tensorflow/python/ops/math_ops_test.py @@ -689,53 +689,5 @@ class RangeTest(test_util.TensorFlowTestCase): self.assertAllEqual(values, self.evaluate(tensor)) -@test_util.run_all_in_graph_and_eager_modes -class ScalarOptimizationTest(test_util.TensorFlowTestCase): - - def testAddZero(self): - x = constant_op.constant(1) - y = math_ops.add_v2(x, 0) - self.assertAllEqual(x, y) - self.assertIs(x, y) - - # Optimization not applied - y = math_ops.add_v2(x, constant_op.constant(0)) - self.assertAllEqual(x, y) - self.assertIsNot(x, y) - - def testSubtractZero(self): - x = constant_op.constant(1) - y = math_ops.subtract(x, 0) - self.assertAllEqual(x, y) - self.assertIs(x, y) - - # Optimization not applied - y = math_ops.subtract(x, constant_op.constant(0)) - self.assertAllEqual(x, y) - self.assertIsNot(x, y) - - def testMultiplyOne(self): - x = constant_op.constant(1) - y = math_ops.multiply(x, 1) - self.assertAllEqual(x, y) - self.assertIs(x, y) - - # Optimization not applied - y = math_ops.multiply(x, constant_op.constant(1)) - self.assertAllEqual(x, y) - self.assertIsNot(x, y) - - def testDivideOne(self): - x = constant_op.constant(1) - y = math_ops.divide(x, 1) - self.assertAllEqual(x, y) - self.assertIs(x, y) - - # Optimization not applied - y = math_ops.divide(x, constant_op.constant(1)) - self.assertAllEqual(x, y) - self.assertIsNot(x, y) - - if __name__ == "__main__": googletest.main() From 90a754c965005fe33dfde1352267be92fea4c095 Mon Sep 17 00:00:00 2001 From: Tiezhen WANG Date: Wed, 4 Dec 2019 00:05:37 -0800 Subject: [PATCH 278/279] TFLM: nit catch memory allocation failures PiperOrigin-RevId: 283699825 Change-Id: I0434f161c1ac6d5578e7a35f26f7a8344cc3cf6f --- tensorflow/lite/experimental/micro/micro_allocator.cc | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tensorflow/lite/experimental/micro/micro_allocator.cc b/tensorflow/lite/experimental/micro/micro_allocator.cc index 700016af510..82b3b350c23 100644 --- a/tensorflow/lite/experimental/micro/micro_allocator.cc +++ b/tensorflow/lite/experimental/micro/micro_allocator.cc @@ -89,6 +89,11 @@ MicroAllocator::MicroAllocator(TfLiteContext* context, const Model* model, reinterpret_cast(memory_allocator_.AllocateFromTail( sizeof(TfLiteTensor) * context_->tensors_size, alignof(TfLiteTensor))); + if (context_->tensors == nullptr) { + error_reporter_->Report( + "Failed to allocate memory for context->tensors, %d bytes required", + sizeof(TfLiteTensor) * context_->tensors_size); + } // Null all inputs so we can later perform a null check to avoid re-allocating // registered pre-allocated inputs. @@ -230,6 +235,12 @@ TfLiteStatus MicroAllocator::FinishTensorAllocation() { TensorInfo* tensor_info = reinterpret_cast(tmp_allocator.AllocateFromTail( sizeof(TensorInfo) * tensors_size, alignof(TensorInfo))); + if (tensor_info == nullptr) { + error_reporter_->Report( + "Failed to allocate memory for tensor_info, %d bytes required", + sizeof(TfLiteTensor) * context_->tensors_size); + return kTfLiteError; + } // Set up the runtime data structures for all tensors. for (size_t i = 0; i < tensors_size; ++i) { From 7c5157667006181f16efa3b70468ec1bd62cb070 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 4 Dec 2019 01:02:54 -0800 Subject: [PATCH 279/279] compat: Update forward compatibility horizon to 2019-12-04 PiperOrigin-RevId: 283706005 Change-Id: I8ad908d54089977d3a2cfadea24885693d7d3938 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 6c3d92593b2..71427c9c237 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -31,7 +31,7 @@ from tensorflow.python.util.tf_export import tf_export # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2019, 12, 3) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2019, 12, 4) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None