diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake
index 19b57f0be0b..61bcd6ab524 100644
--- a/tensorflow/contrib/cmake/tf_core_kernels.cmake
+++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake
@@ -87,11 +87,7 @@ if(WIN32)
       # not working on windows yet
       "${tensorflow_source_dir}/tensorflow/core/kernels/depthwise_conv_op.cc"  # Cannot find symbol: tensorflow::LaunchConv2DOp<struct Eigen::ThreadPoolDevice, double>::launch(...).
       "${tensorflow_source_dir}/tensorflow/core/kernels/fact_op.cc"
-      "${tensorflow_source_dir}/tensorflow/core/kernels/immutable_constant_op.cc"
-      "${tensorflow_source_dir}/tensorflow/core/kernels/immutable_constant_op.h"
       "${tensorflow_source_dir}/tensorflow/core/kernels/meta_support.*"
-      "${tensorflow_source_dir}/tensorflow/core/kernels/sparse_matmul_op.cc"
-      "${tensorflow_source_dir}/tensorflow/core/kernels/sparse_matmul_op.h"
       "${tensorflow_source_dir}/tensorflow/core/kernels/*quantiz*.h"
       "${tensorflow_source_dir}/tensorflow/core/kernels/*quantiz*.cc"
       "${tensorflow_source_dir}/tensorflow/core/kernels/svd*.cc"
diff --git a/tensorflow/core/kernels/immutable_constant_op.cc b/tensorflow/core/kernels/immutable_constant_op.cc
index b9abfdb04c4..0dd08c694eb 100644
--- a/tensorflow/core/kernels/immutable_constant_op.cc
+++ b/tensorflow/core/kernels/immutable_constant_op.cc
@@ -96,9 +96,9 @@ void ImmutableConstantOp::Compute(OpKernelContext* ctx) {
 }
 
 ImmutableConstantOp::~ImmutableConstantOp() {}
-constexpr char ImmutableConstantOp::kDTypeAttr[];
-constexpr char ImmutableConstantOp::kShapeAttr[];
-constexpr char ImmutableConstantOp::kMemoryRegionNameAttr[];
+constexpr char const* ImmutableConstantOp::kDTypeAttr;
+constexpr char const* ImmutableConstantOp::kShapeAttr;
+constexpr char const* ImmutableConstantOp::kMemoryRegionNameAttr;
 
 REGISTER_KERNEL_BUILDER(Name("ImmutableConst").Device(DEVICE_CPU),
                         ImmutableConstantOp);
diff --git a/tensorflow/core/kernels/immutable_constant_op.h b/tensorflow/core/kernels/immutable_constant_op.h
index 2cf9c673089..795331b4b25 100644
--- a/tensorflow/core/kernels/immutable_constant_op.h
+++ b/tensorflow/core/kernels/immutable_constant_op.h
@@ -33,9 +33,9 @@ class ImmutableConstantOp : public OpKernel {
   ~ImmutableConstantOp() override;
 
   // Names of attributes that are used by this op
-  static constexpr char kDTypeAttr[] = "dtype";
-  static constexpr char kShapeAttr[] = "shape";
-  static constexpr char kMemoryRegionNameAttr[] = "memory_region_name";
+  static constexpr char const* kDTypeAttr = "dtype";
+  static constexpr char const* kShapeAttr = "shape";
+  static constexpr char const* kMemoryRegionNameAttr = "memory_region_name";
 
  private:
   string region_name_;
diff --git a/tensorflow/core/kernels/sparse_matmul_op.cc b/tensorflow/core/kernels/sparse_matmul_op.cc
index e5b0b6fcd21..c5460c8db17 100644
--- a/tensorflow/core/kernels/sparse_matmul_op.cc
+++ b/tensorflow/core/kernels/sparse_matmul_op.cc
@@ -34,6 +34,7 @@ limitations under the License.
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/types.h"
 
+
 namespace tensorflow {
 
 namespace {
@@ -134,7 +135,7 @@ struct SparseSlice {
 
 template <typename T>
 template <bool Transpose>
-void SparseSlice<T>::Initialize(const SparseSlice<T>::ConstMatrixMap& mat,
+void SparseSlice<T>::Initialize(const typename SparseSlice<T>::ConstMatrixMap& mat,
                                 int col_offset) {
   const int mat_rows = Transpose ? mat.dimension(1) : mat.dimension(0);
   const int mat_cols = Transpose ? mat.dimension(0) : mat.dimension(1);
@@ -950,7 +951,7 @@ class SparseMatMulOp : public OpKernel {
 template <typename TL, typename TR>
 inline void SparseMatMul<TL, TR>::ComputeOutputBlock(
     const std::vector<SparseSlice<TL>*>& left,
-    const SparseMatMul<TL, TR>::ConstMatrixMapR& right, int num_cols,
+    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right, int num_cols,
     int output_row_offset, int output_col_offset, bool assign,
     bool transpose_output, MatrixMap* output) {
   static const Eigen::array<int, 2> perm({1, 0});
@@ -1000,7 +1001,7 @@ inline void SparseMatMul<TL, TR>::ComputeOutputBlock(
 
 template <typename TL, typename TR>
 inline BlockingCounter* SparseMatMul<TL, TR>::CreateSparseSlices(
-    const SparseMatMul<TL, TR>::ConstMatrixMapL& mat, bool transpose,
+    const typename SparseMatMul<TL, TR>::ConstMatrixMapL& mat, bool transpose,
     int slice_num_rows, int slice_block_size, int slice_num_cols,
     std::vector<std::vector<SparseSlice<TL>*>>* mat_slices,
     const DeviceBase::CpuWorkerThreads* thread_pool) {
@@ -1096,7 +1097,7 @@ ALWAYS_INLINE void CopyAndMayBeInterleave(void* dst, const void* src,
 
 template <typename TL, typename TR>
 inline BlockingCounter* SparseMatMul<TL, TR>::ShuffleMatrix(
-    const SparseMatMul<TL, TR>::ConstMatrixMapR& mat, int slice_row_start,
+    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& mat, int slice_row_start,
     int slice_num_rows, int slice_col_start, int slice_num_cols, const int N,
     const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer) {
   DCHECK_EQ(N % 2, 0);
@@ -1153,7 +1154,7 @@ inline BlockingCounter* SparseMatMul<TL, TR>::ShuffleMatrix(
 template <typename TL, typename TR>
 inline void SparseMatMul<TL, TR>::SliceMatrix(
     const MatrixR& mat, const int num_rows, const int num_slices,
-    std::vector<SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
+    std::vector<typename SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
   slices->resize(num_slices);
   DSizes d(num_rows, mat.dimension(1));
   DCHECK_LE(num_rows * num_slices, mat.dimension(0));
@@ -1164,10 +1165,10 @@ inline void SparseMatMul<TL, TR>::SliceMatrix(
 
 template <typename TL, typename TR>
 inline BlockingCounter* SparseMatMul<TL, TR>::CreateDenseSlices(
-    const SparseMatMul<TL, TR>::ConstMatrixMapR& mat, int row_start,
+    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& mat, int row_start,
     int num_rows, int col_start, int num_cols,
     const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer,
-    std::vector<SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
+    std::vector<typename SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
   BlockingCounter* shuffle_counter = ShuffleMatrix(
       mat, row_start, num_rows, col_start, num_cols, N, thread_pool, buffer);
   const int num_slices = (num_cols + N - 1) / N;
@@ -1177,8 +1178,8 @@ inline BlockingCounter* SparseMatMul<TL, TR>::CreateDenseSlices(
 
 template <typename TL, typename TR>
 inline void SparseMatMul<TL, TR>::ComputeBlockSizes(
-    const SparseMatMul<TL, TR>::ConstMatrixMapL& left,
-    const SparseMatMul<TL, TR>::ConstMatrixMapR& right, bool transpose_left,
+    const typename SparseMatMul<TL, TR>::ConstMatrixMapL& left,
+    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right, bool transpose_left,
     int num_threads, int* KR, int* NR, int* KL, int* JB, int* IB) {
   // Heuristics for calculating block sizes
   // Assume two hyperthreads per core.
@@ -1248,8 +1249,8 @@ inline void SparseMatMul<TL, TR>::ComputeBlockSizes(
 //    {l_i} and JB elements from {r_j} and compute the IB * JB inner products.
 template <typename TL, typename TR>
 inline void SparseMatMul<TL, TR>::Compute(
-    const SparseMatMul<TL, TR>::ConstMatrixMapL& left,
-    const SparseMatMul<TL, TR>::ConstMatrixMapR& right, bool transpose_left,
+    const typename SparseMatMul<TL, TR>::ConstMatrixMapL& left,
+    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right, bool transpose_left,
     const DeviceBase::CpuWorkerThreads* thread_pool, bool transpose_output,
     MatrixMap* output) {
   const int num_threads = thread_pool->num_threads;
diff --git a/tensorflow/core/kernels/sparse_matmul_op.h b/tensorflow/core/kernels/sparse_matmul_op.h
index 97dd285fbf9..4e14f0099ab 100644
--- a/tensorflow/core/kernels/sparse_matmul_op.h
+++ b/tensorflow/core/kernels/sparse_matmul_op.h
@@ -19,6 +19,10 @@ limitations under the License.
 #include "third_party/eigen3/Eigen/Core"
 #include "tensorflow/core/platform/types.h"
 
+#if defined(PLATFORM_WINDOWS)
+#include "tensorflow/core/platform/windows/intrinsics_port.h"
+#endif
+
 namespace Eigen {
 namespace internal {
 
diff --git a/tensorflow/core/platform/windows/intrinsics_port.h b/tensorflow/core/platform/windows/intrinsics_port.h
new file mode 100644
index 00000000000..df4d4862213
--- /dev/null
+++ b/tensorflow/core/platform/windows/intrinsics_port.h
@@ -0,0 +1,41 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_WINDOWS_INTRINSICS_PORT_H_
+#define TENSORFLOW_CORE_PLATFORM_WINDOWS_INTRINSICS_PORT_H_
+
+
+#ifdef _MSC_VER
+// the following avx intrinsics are not defined on windows
+// in immintrin.h so we define them here.
+// 
+#include "tensorflow/core/platform/types.h"
+
+#define _mm_load_pd1 _mm_load1_pd
+static inline int
+_mm256_extract_epi32(__m256i a, const int i)
+{
+  return a.m256i_i32[i & 7];
+}
+
+static inline __m256i
+_mm256_insert_epi32(__m256i a, int b, const int i)
+{
+  __m256i c = a;
+  c.m256i_i32[i & 7] = b;
+  return c;
+}
+#endif
+#endif