Add KafkaReader for processing streaming data with Apache Kafka (#14098)

* Add KafkaReader for processing streaming data with Apache Kafka

Apache Kafka is a widely used distributed streaming platform in the
open source community. The goal of this fix is to create a contrib
Reader op (inheriting ReaderBase, similar to TextLineReader and
TFRecordReader) so that Kafka streaming data can be read from
TensorFlow in a similar fashion.
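
For reference, a minimal sketch of how the result is meant to be consumed
from Python (the reader is converted to a `tf.data`-style `KafkaDataset`
later in this PR). The broker address and the subscription string are
placeholder values; "test:0:0:4" means topic "test", partition 0, starting
at offset 0 with a length limit of 4:

    import tensorflow as tf
    from tensorflow.contrib.kafka.python.ops import kafka_dataset_ops

    # Placeholder subscription: topic "test", partition 0, start offset 0,
    # length 4 (-1 would mean unlimited).
    dataset = kafka_dataset_ops.KafkaDataset(
        topics=["test:0:0:4"], servers="localhost:9092", group="test",
        eof=True)

    iterator = dataset.make_one_shot_iterator()
    next_message = iterator.get_next()

    with tf.Session() as sess:
      try:
        while True:
          # Each element is a scalar string tensor holding one Kafka message.
          print(sess.run(next_message))
      except tf.errors.OutOfRangeError:
        pass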

This fix uses librdkafka, a C/C++ Apache Kafka client library that is
released under the 2-clause BSD license and is widely used in a number
of Kafka bindings such as Go, Python, and C#/.NET.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add KafkaReader Python wrapper.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add BUILD file and op registration for KafkaReader.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add C++ Kernel for KafkaReader

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add librdkafka to third_party packages in Bazel

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add contrib/kafka to the contrib Bazel file.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Update workspace.bzl

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Comment out clean_deps of `tensorflow/core:framework` and `tensorflow/core:lib`

so that it is possible to build with ReaderBase.

See 1419 for details.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add group id flag.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Sync offset

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add test cases and a script to start and stop the Kafka server (with Docker)

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Convert from the legacy librdkafka Consumer to KafkaConsumer

so that the thread join does not hang.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Only output the offset as the key.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add a timeout attr for the Kafka consumer to use when waiting for messages

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Build Kafka kernels by default to get around the linkage issue.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Convert KafkaReader to KafkaDataset.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Fix workspace.bzl for kafka with tf_http_archive

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add public visibility

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Address review feedback

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Optionally select Kafka support through ./configure

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
diff --git a/configure.py b/configure.py
index 083fed1..16763b8 100644
--- a/configure.py
+++ b/configure.py
@@ -1354,6 +1354,7 @@
     environ_cp['TF_NEED_GCP'] = '0'
     environ_cp['TF_NEED_HDFS'] = '0'
     environ_cp['TF_NEED_JEMALLOC'] = '0'
+    environ_cp['TF_NEED_KAFKA'] = '0'
     environ_cp['TF_NEED_OPENCL_SYCL'] = '0'
     environ_cp['TF_NEED_COMPUTECPP'] = '0'
     environ_cp['TF_NEED_OPENCL'] = '0'
@@ -1372,6 +1373,8 @@
                 'with_hdfs_support', True, 'hdfs')
   set_build_var(environ_cp, 'TF_NEED_S3', 'Amazon S3 File System',
                 'with_s3_support', True, 's3')
+  set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform',
+                'with_kafka_support', False, 'kafka')
   set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
                 False, 'xla')
   set_build_var(environ_cp, 'TF_NEED_GDR', 'GDR', 'with_gdr_support',
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index b26c525..9e69613 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -211,6 +211,12 @@
     visibility = ["//visibility:public"],
 )
 
+config_setting(
+    name = "with_kafka_support",
+    define_values = {"with_kafka_support": "true"},
+    visibility = ["//visibility:public"],
+)
+
 # Crosses between platforms and file system libraries not supported on those
 # platforms due to limitations in nested select() statements.
 config_setting(
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index f1e5443..5ac5955 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -48,6 +48,7 @@
         "//tensorflow/contrib/image:single_image_random_dot_stereograms_py",
         "//tensorflow/contrib/input_pipeline:input_pipeline_py",
         "//tensorflow/contrib/integrate:integrate_py",
+        "//tensorflow/contrib/kafka",
         "//tensorflow/contrib/keras",
         "//tensorflow/contrib/kernel_methods",
         "//tensorflow/contrib/kfac",
@@ -139,6 +140,7 @@
         "//tensorflow/contrib/factorization:all_ops",
         "//tensorflow/contrib/framework:all_ops",
         "//tensorflow/contrib/input_pipeline:input_pipeline_ops_op_lib",
+        "//tensorflow/contrib/kafka:kafka_ops_op_lib",
         "//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib",
         "//tensorflow/contrib/nccl:nccl_ops_op_lib",
         "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_ops_op_lib",
diff --git a/tensorflow/contrib/kafka/BUILD b/tensorflow/contrib/kafka/BUILD
new file mode 100644
index 0000000..f7593aa
--- /dev/null
+++ b/tensorflow/contrib/kafka/BUILD
@@ -0,0 +1,104 @@
+package(
+    default_visibility = ["//visibility:private"],
+)
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs")
+load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
+load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+
+tf_kernel_library(
+    name = "kafka_kernels",
+    srcs = ["kernels/kafka_dataset_ops.cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core/kernels:bounds_check_lib",
+        "//tensorflow/core/kernels:dataset",
+        "//third_party/eigen3",
+        "@kafka//:kafka",
+    ],
+)
+
+tf_gen_op_libs(
+    op_lib_names = ["kafka_ops"],
+    deps = [
+        "//tensorflow/core:lib",
+    ],
+)
+
+tf_gen_op_wrapper_py(
+    name = "gen_kafka_ops",
+    out = "python/ops/gen_kafka_ops.py",
+    require_shape_functions = True,
+    deps = [":kafka_ops_op_lib"],
+)
+
+py_library(
+    name = "kafka",
+    srcs = [
+        "__init__.py",
+        "python/ops/kafka_dataset_ops.py",
+    ],
+    srcs_version = "PY2AND3",
+    visibility = ["//visibility:public"],
+    deps = [
+        ":gen_kafka_ops",
+        "//tensorflow/contrib/util:util_py",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:platform",
+        "//tensorflow/python:state_ops",
+        "//tensorflow/python:training",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/data/ops:iterator_ops",
+        "//tensorflow/python/data/ops:readers",
+    ],
+)
+
+# The Kafka server has to be set up before running the test.
+# The Kafka server is set up through Docker so the Docker engine
+# has to be installed.
+#
+# Once the Docker engine is ready:
+# To set up the Kafka server:
+# $ bash tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh start kafka
+#
+# After the test is complete:
+# To tear down the Kafka server:
+# $ bash tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh stop kafka
+tf_py_test(
+    name = "kafka_test",
+    srcs = ["python/kernel_tests/kafka_test.py"],
+    additional_deps = [
+        ":kafka",
+        "//third_party/py/numpy",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+    tags = [
+        "manual",
+    ],
+)
+
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/METADATA",
+            "**/OWNERS",
+        ],
+    ),
+    visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/kafka/__init__.py b/tensorflow/contrib/kafka/__init__.py
new file mode 100644
index 0000000..4d755c4
--- /dev/null
+++ b/tensorflow/contrib/kafka/__init__.py
@@ -0,0 +1,32 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Kafka Dataset.
+
+@@KafkaDataset
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.kafka.python.ops.kafka_dataset_ops import KafkaDataset
+
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+    "KafkaDataset",
+]
+
+remove_undocumented(__name__)
diff --git a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc
new file mode 100644
index 0000000..88ef5f3
--- /dev/null
+++ b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc
@@ -0,0 +1,321 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/dataset.h"
+
+#include "tensorflow/core/framework/tensor.h"
+
+#include "src-cpp/rdkafkacpp.h"
+
+namespace tensorflow {
+
+class KafkaDatasetOp : public DatasetOpKernel {
+ public:
+  using DatasetOpKernel::DatasetOpKernel;
+
+  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+    const Tensor* topics_tensor;
+    OP_REQUIRES_OK(ctx, ctx->input("topics", &topics_tensor));
+    OP_REQUIRES(
+        ctx, topics_tensor->dims() <= 1,
+        errors::InvalidArgument("`topics` must be a scalar or a vector."));
+
+    std::vector<string> topics;
+    topics.reserve(topics_tensor->NumElements());
+    for (int i = 0; i < topics_tensor->NumElements(); ++i) {
+      topics.push_back(topics_tensor->flat<string>()(i));
+    }
+
+    std::string servers = "";
+    OP_REQUIRES_OK(ctx,
+                   ParseScalarArgument<std::string>(ctx, "servers", &servers));
+    std::string group = "";
+    OP_REQUIRES_OK(ctx, ParseScalarArgument<std::string>(ctx, "group", &group));
+    bool eof = false;
+    OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "eof", &eof));
+    int64 timeout = -1;
+    OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "timeout", &timeout));
+    OP_REQUIRES(ctx, (timeout > 0),
+                errors::InvalidArgument(
+                    "Timeout value should be larger than 0, got ", timeout));
+    *output = new Dataset(ctx, std::move(topics), servers, group, eof, timeout);
+  }
+
+ private:
+  class Dataset : public GraphDatasetBase {
+   public:
+    Dataset(OpKernelContext* ctx, std::vector<string> topics,
+            const string& servers, const string& group, const bool eof,
+            const int64 timeout)
+        : GraphDatasetBase(ctx),
+          topics_(std::move(topics)),
+          servers_(servers),
+          group_(group),
+          eof_(eof),
+          timeout_(timeout) {}
+
+    std::unique_ptr<IteratorBase> MakeIterator(
+        const string& prefix) const override {
+      return std::unique_ptr<IteratorBase>(
+          new Iterator({this, strings::StrCat(prefix, "::Kafka")}));
+    }
+
+    const DataTypeVector& output_dtypes() const override {
+      static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
+      return *dtypes;
+    }
+
+    const std::vector<PartialTensorShape>& output_shapes() const override {
+      static std::vector<PartialTensorShape>* shapes =
+          new std::vector<PartialTensorShape>({{}});
+      return *shapes;
+    }
+
+    string DebugString() override { return "KafkaDatasetOp::Dataset"; }
+
+   protected:
+    Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
+                              Node** output) const override {
+      Node* topics = nullptr;
+      TF_RETURN_IF_ERROR(b->AddVector(topics_, &topics));
+      Node* servers = nullptr;
+      TF_RETURN_IF_ERROR(b->AddScalar(servers_, &servers));
+      Node* group = nullptr;
+      TF_RETURN_IF_ERROR(b->AddScalar(group_, &group));
+      Node* eof = nullptr;
+      TF_RETURN_IF_ERROR(b->AddScalar(eof_, &eof));
+      Node* timeout = nullptr;
+      TF_RETURN_IF_ERROR(b->AddScalar(timeout_, &timeout));
+      TF_RETURN_IF_ERROR(
+          b->AddDataset(this, {topics, servers, group, eof, timeout}, output));
+      return Status::OK();
+    }
+
+   private:
+    class Iterator : public DatasetIterator<Dataset> {
+     public:
+      explicit Iterator(const Params& params)
+          : DatasetIterator<Dataset>(params) {}
+
+      Status GetNextInternal(IteratorContext* ctx,
+                             std::vector<Tensor>* out_tensors,
+                             bool* end_of_sequence) override {
+        mutex_lock l(mu_);
+        do {
+          // We are currently processing a topic, so try to read the next line.
+          if (consumer_.get()) {
+            while (true) {
+              if (limit_ >= 0 &&
+                  (topic_partition_->offset() >= limit_ || offset_ >= limit_)) {
+                // EOF current topic
+                break;
+              }
+              std::unique_ptr<RdKafka::Message> message(
+                  consumer_->consume(dataset()->timeout_));
+              if (message->err() == RdKafka::ERR_NO_ERROR) {
+                // Produce the line as output.
+                Tensor line_tensor(cpu_allocator(), DT_STRING, {});
+                line_tensor.scalar<string>()() =
+                    std::string(static_cast<const char*>(message->payload()),
+                                message->len());
+                out_tensors->emplace_back(std::move(line_tensor));
+                *end_of_sequence = false;
+                // Sync offset
+                offset_ = message->offset();
+                return Status::OK();
+              }
+
+              if (message->err() == RdKafka::ERR__PARTITION_EOF &&
+                  dataset()->eof_) {
+                // EOF current topic
+                break;
+              }
+              if (message->err() != RdKafka::ERR__TIMED_OUT) {
+                return errors::Internal("Failed to consume:",
+                                        message->errstr());
+              }
+              message.reset(nullptr);
+              consumer_->poll(0);
+            }
+
+            // We have reached the end of the current topic, so maybe
+            // move on to the next topic.
+            ResetStreamsLocked();
+            ++current_topic_index_;
+          }
+
+          // Iteration ends when there are no more topics to process.
+          if (current_topic_index_ == dataset()->topics_.size()) {
+            *end_of_sequence = true;
+            return Status::OK();
+          }
+
+          TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
+        } while (true);
+      }
+
+     protected:
+      Status SaveInternal(IteratorStateWriter* writer) override {
+        mutex_lock l(mu_);
+        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_topic_index"),
+                                               current_topic_index_));
+
+        // `consumer_` is empty if
+        // 1. GetNext has not been called even once.
+        // 2. All topics have been read and iterator has been exhausted.
+        if (consumer_.get()) {
+          TF_RETURN_IF_ERROR(
+              writer->WriteScalar(full_name("current_pos"), offset_));
+        }
+        return Status::OK();
+      }
+
+      Status RestoreInternal(IteratorContext* ctx,
+                             IteratorStateReader* reader) override {
+        mutex_lock l(mu_);
+        ResetStreamsLocked();
+        int64 current_topic_index;
+        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_topic_index"),
+                                              &current_topic_index));
+        current_topic_index_ = size_t(current_topic_index);
+        // The key "current_pos" is written only if the iterator was saved
+        // with an open topic.
+        if (reader->Contains(full_name("current_pos"))) {
+          int64 current_pos;
+          TF_RETURN_IF_ERROR(
+              reader->ReadScalar(full_name("current_pos"), &current_pos));
+
+          TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
+          topic_partition_->set_offset(current_pos);
+          if (topic_partition_->offset() != current_pos) {
+            return errors::Internal("Failed to restore to offset ",
+                                    current_pos);
+          }
+          offset_ = current_pos;
+        }
+        return Status::OK();
+      }
+
+     private:
+      // Sets up Kafka streams to read from the topic at
+      // `current_topic_index_`.
+      Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        if (current_topic_index_ >= dataset()->topics_.size()) {
+          return errors::InvalidArgument(
+              "current_topic_index_:", current_topic_index_,
+              " >= topics_.size():", dataset()->topics_.size());
+        }
+
+        // Actually move on to the next topic.
+        string entry = dataset()->topics_[current_topic_index_];
+
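+        // Each subscription entry has the form "topic:partition:offset:length";
+        // partition, offset, and length are optional and default to 0, 0, and
+        // -1 (unlimited) respectively.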
+        std::vector<string> parts = str_util::Split(entry, ":");
+        if (parts.size() < 1) {
+          return errors::InvalidArgument("Invalid parameters: ", entry);
+        }
+        string topic = parts[0];
+        int32 partition = 0;
+        if (parts.size() > 1) {
+          if (!strings::safe_strto32(parts[1], &partition)) {
+            return errors::InvalidArgument("Invalid parameters: ", entry);
+          }
+        }
+        int64 offset = 0;
+        if (parts.size() > 2) {
+          if (!strings::safe_strto64(parts[2], &offset)) {
+            return errors::InvalidArgument("Invalid parameters: ", entry);
+          }
+        }
+
+        topic_partition_.reset(
+            RdKafka::TopicPartition::create(topic, partition, offset));
+
+        offset_ = topic_partition_->offset();
+        limit_ = -1;
+        if (parts.size() > 3) {
+          if (!strings::safe_strto64(parts[3], &limit_)) {
+            return errors::InvalidArgument("Invalid parameters: ", entry);
+          }
+        }
+
+        std::unique_ptr<RdKafka::Conf> conf(
+            RdKafka::Conf::create(RdKafka::Conf::CONF_GLOBAL));
+        std::unique_ptr<RdKafka::Conf> topic_conf(
+            RdKafka::Conf::create(RdKafka::Conf::CONF_TOPIC));
+
+        std::string errstr;
+
+        RdKafka::Conf::ConfResult result =
+            conf->set("default_topic_conf", topic_conf.get(), errstr);
+        if (result != RdKafka::Conf::CONF_OK) {
+          return errors::Internal("Failed to set default_topic_conf:", errstr);
+        }
+
+        result = conf->set("bootstrap.servers", dataset()->servers_, errstr);
+        if (result != RdKafka::Conf::CONF_OK) {
+          return errors::Internal("Failed to set bootstrap.servers ",
+                                  dataset()->servers_, ":", errstr);
+        }
+        result = conf->set("group.id", dataset()->group_, errstr);
+        if (result != RdKafka::Conf::CONF_OK) {
+          return errors::Internal("Failed to set group.id ", dataset()->group_,
+                                  ":", errstr);
+        }
+
+        consumer_.reset(RdKafka::KafkaConsumer::create(conf.get(), errstr));
+        if (!consumer_.get()) {
+          return errors::Internal("Failed to create consumer:", errstr);
+        }
+
+        std::vector<RdKafka::TopicPartition*> partitions;
+        partitions.emplace_back(topic_partition_.get());
+        RdKafka::ErrorCode err = consumer_->assign(partitions);
+        if (err != RdKafka::ERR_NO_ERROR) {
+          return errors::Internal(
+              "Failed to assign partition [", topic_partition_->topic(), ", ",
+              topic_partition_->partition(), ", ", topic_partition_->offset(),
+              "]:", RdKafka::err2str(err));
+        }
+
+        return Status::OK();
+      }
+
+      // Resets all Kafka streams.
+      void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        consumer_->unassign();
+        consumer_->close();
+        consumer_.reset(nullptr);
+      }
+
+      mutex mu_;
+      size_t current_topic_index_ GUARDED_BY(mu_) = 0;
+      int64 offset_ GUARDED_BY(mu_) = 0;
+      int64 limit_ GUARDED_BY(mu_) = -1;
+      std::unique_ptr<RdKafka::TopicPartition> topic_partition_ GUARDED_BY(mu_);
+      std::unique_ptr<RdKafka::KafkaConsumer> consumer_ GUARDED_BY(mu_);
+    };
+
+    const std::vector<string> topics_;
+    const std::string servers_;
+    const std::string group_;
+    const bool eof_;
+    const int64 timeout_;
+  };
+};
+
+REGISTER_KERNEL_BUILDER(Name("KafkaDataset").Device(DEVICE_CPU),
+                        KafkaDatasetOp);
+
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/kafka/ops/kafka_ops.cc b/tensorflow/contrib/kafka/ops/kafka_ops.cc
new file mode 100644
index 0000000..8cdf161
--- /dev/null
+++ b/tensorflow/contrib/kafka/ops/kafka_ops.cc
@@ -0,0 +1,44 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+REGISTER_OP("KafkaDataset")
+    .Input("topics: string")
+    .Input("servers: string")
+    .Input("group: string")
+    .Input("eof: bool")
+    .Input("timeout: int64")
+    .Output("handle: variant")
+    .SetIsStateful()
+    .SetShapeFn(shape_inference::ScalarShape)
+    .Doc(R"doc(
+Creates a dataset that emits the messages of one or more Kafka topics.
+
+topics: A `tf.string` tensor containing one or more subscriptions,
+  each in the format of [topic:partition:offset:length],
+  where length defaults to -1 for unlimited.
+servers: A list of bootstrap servers.
+group: The consumer group id.
+eof: If True, the Kafka reader will stop on EOF.
+timeout: The timeout value for the Kafka Consumer to wait
+  (in milliseconds).
+)doc");
+
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py
new file mode 100644
index 0000000..94cf6b5
--- /dev/null
+++ b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py
@@ -0,0 +1,117 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License.  You may obtain a copy of
+# the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
+# License for the specific language governing permissions and limitations under
+# the License.
+# ==============================================================================
+"""Tests for KafkaDataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import os
+
+from tensorflow.contrib.kafka.python.ops import kafka_dataset_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.lib.io import python_io
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+class KafkaDatasetTest(test.TestCase):
+
+  def setUp(self):
+    # The Kafka server has to be set up before the test
+    # and torn down after the test manually.
+    # The Docker engine has to be installed.
+    #
+    # To set up the Kafka server:
+    # $ bash kafka_test.sh start kafka
+    #
+    # To tear down the Kafka server:
+    # $ bash kafka_test.sh stop kafka
+    pass
+
+  def testKafkaDataset(self):
+    topics = array_ops.placeholder(dtypes.string, shape=[None])
+    num_epochs = array_ops.placeholder(dtypes.int64, shape=[])
+    batch_size = array_ops.placeholder(dtypes.int64, shape=[])
+
+    repeat_dataset = kafka_dataset_ops.KafkaDataset(
+        topics, group="test", eof=True).repeat(num_epochs)
+    batch_dataset = repeat_dataset.batch(batch_size)
+
+    iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
+    init_op = iterator.make_initializer(repeat_dataset)
+    init_batch_op = iterator.make_initializer(batch_dataset)
+    get_next = iterator.get_next()
+
+    with self.test_session() as sess:
+      # Basic test: read the first range of messages (offsets 0 through 4).
+      sess.run(
+          init_op, feed_dict={topics: ["test:0:0:4"],
+                              num_epochs: 1})
+      for i in range(5):
+        self.assertEqual("D"+str(i), sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+      # Basic test: read the second range of messages (offset 5 to the end).
+      sess.run(
+          init_op, feed_dict={topics: ["test:0:5:-1"],
+                              num_epochs: 1})
+      for i in range(5):
+        self.assertEqual("D"+str(i + 5), sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+      # Basic test: read both ranges back to back.
+      sess.run(init_op, feed_dict={topics: ["test:0:0:4", "test:0:5:-1"],
+                                   num_epochs: 1})
+      for j in range(2):
+        for i in range(5):
+          self.assertEqual("D"+str(i + j * 5), sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+      # Test repeated iteration through both ranges.
+      sess.run(init_op, feed_dict={topics: ["test:0:0:4", "test:0:5:-1"],
+                                   num_epochs: 10})
+      for _ in range(10):
+        for j in range(2):
+          for i in range(5):
+            self.assertEqual("D"+str(i + j * 5), sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+      # Test batched and repeated iteration through both ranges.
+      sess.run(
+          init_batch_op,
+          feed_dict={topics: ["test:0:0:4", "test:0:5:-1"],
+                     num_epochs: 10,
+                     batch_size: 5})
+      for _ in range(10):
+        self.assertAllEqual(["D"+str(i) for i in range(5)],
+                            sess.run(get_next))
+        self.assertAllEqual(["D"+str(i + 5) for i in range(5)],
+                            sess.run(get_next))
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh
new file mode 100644
index 0000000..7997c12
--- /dev/null
+++ b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh
@@ -0,0 +1,34 @@
+#!/usr/bin/env bash
+
+set -e
+set -o pipefail
+
+if [ "$#" -ne 2 ]; then
+  echo "Usage: $0 start|stop <kafka container name>" >&2
+  exit 1
+fi
+
+container=$2
+if [ "$1" == "start" ]; then
+    docker run -d --rm --net=host --name=$container spotify/kafka
+    echo Wait 5 secs until kafka is up and running
+    sleep 5
+    echo Create test topic
+    docker exec $container bash -c '/opt/kafka_2.11-0.10.1.0/bin/kafka-topics.sh --create --zookeeper localhost:2181 --replication-factor 1 --partitions 1 --topic test'
+    echo Create test message
+    docker exec $container bash -c 'echo -e "D0\nD1\nD2\nD3\nD4\nD5\nD6\nD7\nD8\nD9" > /test'
+    echo Produce test message
+    docker exec $container bash -c '/opt/kafka_2.11-0.10.1.0/bin/kafka-console-producer.sh --topic test --broker-list 127.0.0.1:9092 < /test'
+
+    echo Container $container started successfully
+elif [ "$1" == "stop" ]; then
+    docker rm -f $container
+
+    echo Container $container stopped successfully
+else
+  echo "Usage: $0 start|stop <kafka container name>" >&2
+  exit 1
+fi
+
+
+
diff --git a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py
new file mode 100644
index 0000000..6590d86
--- /dev/null
+++ b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py
@@ -0,0 +1,72 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Kafka Dataset."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.kafka.python.ops import gen_kafka_ops
+from tensorflow.contrib.util import loader
+from tensorflow.python.data.ops.readers import Dataset
+from tensorflow.python.framework import common_shapes
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.platform import resource_loader
+
+class KafkaDataset(Dataset):
+  """A Kafka Dataset that consumes the message.
+  """
+
+  def __init__(self, topics, servers="localhost", group="", eof=False, timeout=1000):
+    """Create a KafkaDataset.
+
+    Args:
+      topics: A `tf.string` tensor containing one or more subscriptions,
+              each in the format of [topic:partition:offset:length],
+              where length defaults to -1 for unlimited.
+      servers: A list of bootstrap servers.
+      group: The consumer group id.
+      eof: If True, the Kafka reader will stop on EOF.
+      timeout: The timeout value for the Kafka Consumer to wait
+               (in milliseconds).
+    """
+    super(KafkaDataset, self).__init__()
+    self._topics = ops.convert_to_tensor(
+        topics, dtype=dtypes.string, name="topics")
+    self._servers = ops.convert_to_tensor(
+        servers, dtype=dtypes.string, name="servers")
+    self._group = ops.convert_to_tensor(
+        group, dtype=dtypes.string, name="group")
+    self._eof = ops.convert_to_tensor(
+        eof, dtype=dtypes.bool, name="eof")
+    self._timeout = ops.convert_to_tensor(
+        timeout, dtype=dtypes.int64, name="timeout")
+
+  def _as_variant_tensor(self):
+    return gen_kafka_ops.kafka_dataset(
+        self._topics, self._servers, self._group, self._eof, self._timeout)
+
+  @property
+  def output_classes(self):
+    return ops.Tensor
+
+  @property
+  def output_shapes(self):
+    return tensor_shape.scalar()
+
+  @property
+  def output_types(self):
+    return dtypes.string
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 2102c5c..119ffa3 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -489,6 +489,12 @@
           "//tensorflow/core/platform/s3:s3_file_system",
       ],
       "//conditions:default": [],
+  }) + select({
+      "//tensorflow:with_kafka_support": [
+          "//tensorflow/contrib/kafka:kafka_kernels",
+          "//tensorflow/contrib/kafka:kafka_ops_op_lib",
+      ],
+      "//conditions:default": [],
   })
 
 # TODO(jart, jhseu): Delete when GCP is default on.
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index f7d9075..f9c13e5 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -561,6 +561,18 @@
   )
 
   tf_http_archive(
+      name = "kafka",
+      urls = [
+          "https://mirror.bazel.build/github.com/edenhill/librdkafka/archive/v0.11.1.tar.gz",
+          "https://github.com/edenhill/librdkafka/archive/v0.11.1.tar.gz",
+      ],
+      sha256 = "dd035d57c8f19b0b612dd6eefe6e5eebad76f506e302cccb7c2066f25a83585e",
+      strip_prefix = "librdkafka-0.11.1",
+      build_file = str(Label("//third_party:kafka/BUILD")),
+      patch_file = str(Label("//third_party/kafka:config.patch")),
+  )
+
+  tf_http_archive(
       name = "aws",
       urls = [
           "https://mirror.bazel.build/github.com/aws/aws-sdk-cpp/archive/1.3.15.tar.gz",
diff --git a/third_party/kafka/BUILD b/third_party/kafka/BUILD
new file mode 100644
index 0000000..a61a9e1
--- /dev/null
+++ b/third_party/kafka/BUILD
@@ -0,0 +1,147 @@
+# Description:
+#   Kafka C/C++ (librdkafka) client library
+
+licenses(["notice"])  # 2-clause BSD license
+
+exports_files(["LICENSE"])
+
+cc_library(
+    name = "kafka",
+    srcs = [
+        "config.h",
+        "src-cpp/ConfImpl.cpp",
+        "src-cpp/ConsumerImpl.cpp",
+        "src-cpp/HandleImpl.cpp",
+        "src-cpp/KafkaConsumerImpl.cpp",
+        "src-cpp/MessageImpl.cpp",
+        "src-cpp/MetadataImpl.cpp",
+        "src-cpp/QueueImpl.cpp",
+        "src-cpp/RdKafka.cpp",
+        "src-cpp/TopicImpl.cpp",
+        "src-cpp/TopicPartitionImpl.cpp",
+        "src/crc32c.c",
+        "src/crc32c.h",
+        "src/lz4.c",
+        "src/lz4.h",
+        "src/lz4frame.c",
+        "src/lz4frame.h",
+        "src/lz4frame_static.h",
+        "src/lz4hc.c",
+        "src/lz4hc.h",
+        "src/lz4opt.h",
+        "src/queue.h",
+        "src/rd.h",
+        "src/rdaddr.c",
+        "src/rdaddr.h",
+        "src/rdatomic.h",
+        "src/rdavg.h",
+        "src/rdavl.c",
+        "src/rdavl.h",
+        "src/rdbuf.c",
+        "src/rdbuf.h",
+        "src/rdcrc32.h",
+        "src/rddl.h",
+        "src/rdendian.h",
+        "src/rdgz.c",
+        "src/rdgz.h",
+        "src/rdinterval.h",
+        "src/rdkafka.c",
+        "src/rdkafka.h",
+        "src/rdkafka_assignor.c",
+        "src/rdkafka_assignor.h",
+        "src/rdkafka_broker.c",
+        "src/rdkafka_broker.h",
+        "src/rdkafka_buf.c",
+        "src/rdkafka_buf.h",
+        "src/rdkafka_cgrp.c",
+        "src/rdkafka_cgrp.h",
+        "src/rdkafka_conf.c",
+        "src/rdkafka_conf.h",
+        "src/rdkafka_event.h",
+        "src/rdkafka_feature.c",
+        "src/rdkafka_feature.h",
+        "src/rdkafka_int.h",
+        "src/rdkafka_interceptor.c",
+        "src/rdkafka_interceptor.h",
+        "src/rdkafka_lz4.c",
+        "src/rdkafka_lz4.h",
+        "src/rdkafka_metadata.c",
+        "src/rdkafka_metadata.h",
+        "src/rdkafka_metadata_cache.c",
+        "src/rdkafka_msg.c",
+        "src/rdkafka_msg.h",
+        "src/rdkafka_msgset.h",
+        "src/rdkafka_msgset_reader.c",
+        "src/rdkafka_msgset_writer.c",
+        "src/rdkafka_offset.c",
+        "src/rdkafka_offset.h",
+        "src/rdkafka_op.c",
+        "src/rdkafka_op.h",
+        "src/rdkafka_partition.c",
+        "src/rdkafka_partition.h",
+        "src/rdkafka_pattern.c",
+        "src/rdkafka_pattern.h",
+        "src/rdkafka_proto.h",
+        "src/rdkafka_queue.c",
+        "src/rdkafka_queue.h",
+        "src/rdkafka_range_assignor.c",
+        "src/rdkafka_request.c",
+        "src/rdkafka_request.h",
+        "src/rdkafka_roundrobin_assignor.c",
+        "src/rdkafka_sasl.c",
+        "src/rdkafka_sasl.h",
+        "src/rdkafka_sasl_int.h",
+        "src/rdkafka_sasl_plain.c",
+        "src/rdkafka_subscription.c",
+        "src/rdkafka_subscription.h",
+        "src/rdkafka_timer.c",
+        "src/rdkafka_timer.h",
+        "src/rdkafka_topic.c",
+        "src/rdkafka_topic.h",
+        "src/rdkafka_transport.c",
+        "src/rdkafka_transport.h",
+        "src/rdkafka_transport_int.h",
+        "src/rdlist.c",
+        "src/rdlist.h",
+        "src/rdlog.c",
+        "src/rdlog.h",
+        "src/rdports.c",
+        "src/rdports.h",
+        "src/rdposix.h",
+        "src/rdrand.c",
+        "src/rdrand.h",
+        "src/rdregex.c",
+        "src/rdregex.h",
+        "src/rdstring.c",
+        "src/rdstring.h",
+        "src/rdsysqueue.h",
+        "src/rdtime.h",
+        "src/rdtypes.h",
+        "src/rdunittest.c",
+        "src/rdunittest.h",
+        "src/rdvarint.c",
+        "src/rdvarint.h",
+        "src/snappy.c",
+        "src/snappy.h",
+        "src/tinycthread.c",
+        "src/tinycthread.h",
+        "src/xxhash.c",
+        "src/xxhash.h",
+    ],
+    hdrs = [
+        "config.h",
+    ],
+    defines = [
+    ],
+    includes = [
+        "src",
+        "src-cpp",
+    ],
+    linkopts = [
+        "-lpthread",
+    ],
+    visibility = ["//visibility:public"],
+    deps = [
+        "@boringssl//:ssl",
+    ],
+)
diff --git a/third_party/kafka/config.patch b/third_party/kafka/config.patch
new file mode 100644
index 0000000..fa5c2d3
--- /dev/null
+++ b/third_party/kafka/config.patch
@@ -0,0 +1,44 @@
+diff -Naur a/config.h b/config.h
+--- a/config.h	1970-01-01 00:00:00.000000000 +0000
++++ b/config.h	2017-10-28 00:57:03.316957390 +0000
+@@ -0,0 +1,40 @@
++#pragma once
++#define WITHOUT_OPTIMIZATION 0
++#define ENABLE_DEVEL 0
++#define ENABLE_REFCNT_DEBUG 0
++#define ENABLE_SHAREDPTR_DEBUG 0
++
++#define HAVE_ATOMICS_32 1
++#define HAVE_ATOMICS_32_SYNC 1
++
++#if (HAVE_ATOMICS_32)
++# if (HAVE_ATOMICS_32_SYNC)
++#  define ATOMIC_OP32(OP1,OP2,PTR,VAL) __sync_ ## OP1 ## _and_ ## OP2(PTR, VAL)
++# else
++#  define ATOMIC_OP32(OP1,OP2,PTR,VAL) __atomic_ ## OP1 ## _ ## OP2(PTR, VAL, __ATOMIC_SEQ_CST)
++# endif
++#endif
++
++#define HAVE_ATOMICS_64 1
++#define HAVE_ATOMICS_64_SYNC 1
++
++#if (HAVE_ATOMICS_64)
++# if (HAVE_ATOMICS_64_SYNC)
++#  define ATOMIC_OP64(OP1,OP2,PTR,VAL) __sync_ ## OP1 ## _and_ ## OP2(PTR, VAL)
++# else
++#  define ATOMIC_OP64(OP1,OP2,PTR,VAL) __atomic_ ## OP1 ## _ ## OP2(PTR, VAL, __ATOMIC_SEQ_CST)
++# endif
++#endif
++
++
++#define WITH_ZLIB 1
++#define WITH_LIBDL 1
++#define WITH_PLUGINS 0
++#define WITH_SNAPPY 1
++#define WITH_SOCKEM 1
++#define WITH_SSL 1
++#define WITH_SASL 0
++#define WITH_SASL_SCRAM 0
++#define WITH_SASL_CYRUS 0
++#define HAVE_REGEX 1
++#define HAVE_STRNDUP 1