Initial API for tf.contrib.distributions.
Change: 115725802
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index dda77f7..fafc6e8 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -13,6 +13,7 @@
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/ctc:ctc_py",
+ "//tensorflow/contrib/distributions:distributions_py",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",
"//tensorflow/contrib/testing:testing_py",
diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py
index 269d439..f9290ed 100644
--- a/tensorflow/contrib/__init__.py
+++ b/tensorflow/contrib/__init__.py
@@ -21,6 +21,7 @@
# Add projects here, they will show up under tf.contrib.
from tensorflow.contrib import ctc
+from tensorflow.contrib import distributions
from tensorflow.contrib import layers
from tensorflow.contrib import linear_optimizer
from tensorflow.contrib import testing
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
new file mode 100644
index 0000000..a5fde45
--- /dev/null
+++ b/tensorflow/contrib/distributions/BUILD
@@ -0,0 +1,49 @@
+# Description:
+# Contains ops for representing statistical distributions in TensorFlow.
+# APIs here are meant to evolve over time.
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
+
+# Python library for this package: __init__.py plus every module matched by
+# python/ops/*.py.
+py_library(
+ name = "distributions_py",
+ srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
+ srcs_version = "PY2AND3",
+)
+
+# Test target for python/kernel_tests/gaussian_test.py; depends on
+# :distributions_py and the TensorFlow test support libraries listed below.
+cuda_py_tests(
+ name = "gaussian_test",
+ srcs = ["python/kernel_tests/gaussian_test.py"],
+ additional_deps = [
+ ":distributions_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+# Test target for python/kernel_tests/gaussian_conjugate_posteriors_test.py;
+# same dependencies as gaussian_test.
+cuda_py_tests(
+ name = "gaussian_conjugate_posteriors_test",
+ srcs = ["python/kernel_tests/gaussian_conjugate_posteriors_test.py"],
+ additional_deps = [
+ ":distributions_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+# Every file under this package except METADATA and OWNERS files, visible to
+# all packages below //tensorflow.
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py
new file mode 100644
index 0000000..46aae25
--- /dev/null
+++ b/tensorflow/contrib/distributions/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ops for representing statistical distributions.
+
+## This package provides classes for statistical distributions.
+
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,wildcard-import, line-too-long
+from tensorflow.contrib.distributions.python.ops import gaussian_conjugate_posteriors
+from tensorflow.contrib.distributions.python.ops.gaussian import *
+# from tensorflow.contrib.distributions.python.ops.dirichlet import * # pylint: disable=line-too-long
+# from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import * # pylint: disable=line-too-long
diff --git a/tensorflow/contrib/distributions/python/__init__.py b/tensorflow/contrib/distributions/python/__init__.py
new file mode 100644
index 0000000..c9b177d
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/gaussian_conjugate_posteriors_test.py b/tensorflow/contrib/distributions/python/kernel_tests/gaussian_conjugate_posteriors_test.py
new file mode 100644
index 0000000..115f56f
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/gaussian_conjugate_posteriors_test.py
@@ -0,0 +1,65 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the Gaussian conjugate posterior helpers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import tensorflow as tf
+
+gaussian_conjugate_posteriors = tf.contrib.distributions.gaussian_conjugate_posteriors # pylint: disable=line-too-long
+
+
+class GaussianTest(tf.test.TestCase):
+  """Smoke tests for the known-sigma Gaussian conjugate helpers."""
+
+  def testGaussianConjugateKnownSigmaPosterior(self):
+    """The posterior is a Gaussian whose log_pdf has one value per sample."""
+    with tf.Session():
+      mu0 = tf.constant(3.0)
+      sigma0 = tf.constant(math.sqrt(1/0.1))
+      sigma = tf.constant(math.sqrt(1/0.5))
+      observations = tf.constant([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0])
+      # Sufficient statistics of the observations: their sum and their count.
+      total = tf.reduce_sum(observations)
+      count = tf.size(observations)
+      prior = tf.contrib.distributions.Gaussian(mu=mu0, sigma=sigma0)
+      posterior = gaussian_conjugate_posteriors.known_sigma_posterior(
+          prior=prior, sigma=sigma, s=total, n=count)
+
+      # Smoke test: only the result type and the output shape are checked.
+      is_gaussian = isinstance(posterior, tf.contrib.distributions.Gaussian)
+      self.assertTrue(is_gaussian)
+      log_pdf_values = posterior.log_pdf(observations).eval()
+      self.assertEqual(log_pdf_values.shape, (6,))
+
+  def testGaussianConjugateKnownSigmaPredictive(self):
+    """The predictive is a Gaussian whose log_pdf has one value per sample."""
+    with tf.Session():
+      mu0 = tf.constant(3.0)
+      sigma0 = tf.constant(math.sqrt(1/0.1))
+      sigma = tf.constant(math.sqrt(1/0.5))
+      observations = tf.constant([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0])
+      # Sufficient statistics of the observations: their sum and their count.
+      total = tf.reduce_sum(observations)
+      count = tf.size(observations)
+      prior = tf.contrib.distributions.Gaussian(mu=mu0, sigma=sigma0)
+      predictive = gaussian_conjugate_posteriors.known_sigma_predictive(
+          prior=prior, sigma=sigma, s=total, n=count)
+
+      # Smoke test: only the result type and the output shape are checked.
+      is_gaussian = isinstance(predictive, tf.contrib.distributions.Gaussian)
+      self.assertTrue(is_gaussian)
+      log_pdf_values = predictive.log_pdf(observations).eval()
+      self.assertEqual(log_pdf_values.shape, (6,))
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/gaussian_test.py b/tensorflow/contrib/distributions/python/kernel_tests/gaussian_test.py
new file mode 100644
index 0000000..c20cb6d
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/gaussian_test.py
@@ -0,0 +1,77 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the Gaussian distribution."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import numpy as np
+import tensorflow as tf
+
+
+class GaussianTest(tf.test.TestCase):
+ """Checks Gaussian pdf, cdf and sampling against NumPy/math references."""
+
+ def testGaussianLogLikelihoodPDF(self):
+ """log_pdf and pdf match the closed-form density evaluated with NumPy."""
+ with tf.Session():
+ mu = tf.constant(3.0)
+ sigma = tf.constant(math.sqrt(1/0.1))
+ # Plain-number copies of the parameters for the NumPy reference below.
+ mu_v = 3.0
+ sigma_v = np.sqrt(1/0.1)
+ x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0])
+ gaussian = tf.contrib.distributions.Gaussian(mu=mu, sigma=sigma)
+ # Reference: log of 1/(sqrt(2*pi)*sigma) * exp(-(x-mu)^2 / (2*sigma^2)).
+ expected_log_pdf = np.log(
+ 1/np.sqrt(2*np.pi)/sigma_v*np.exp(-1.0/(2*sigma_v**2)*(x-mu_v)**2))
+
+ log_pdf = gaussian.log_pdf(x)
+ self.assertAllClose(expected_log_pdf, log_pdf.eval())
+
+ pdf = gaussian.pdf(x)
+ self.assertAllClose(np.exp(expected_log_pdf), pdf.eval())
+
+ def testGaussianCDF(self):
+ """cdf matches the erf-based closed form evaluated with math.erf."""
+ with tf.Session():
+ mu = tf.constant(3.0)
+ sigma = tf.constant(math.sqrt(1/0.1))
+ # Plain-number copies of the parameters for the reference below.
+ mu_v = 3.0
+ sigma_v = np.sqrt(1/0.1)
+ x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0])
+ gaussian = tf.contrib.distributions.Gaussian(mu=mu, sigma=sigma)
+ # math.erf only accepts scalars, so vectorize it to apply it to arrays.
+ erf_fn = np.vectorize(math.erf)
+
+ # From Wikipedia
+ expected_cdf = 0.5*(1.0 + erf_fn((x - mu_v)/(sigma_v*np.sqrt(2))))
+
+ cdf = gaussian.cdf(x)
+ self.assertAllClose(expected_cdf, cdf.eval())
+
+ def testGaussianSample(self):
+ """Seeded samples have the requested shape and roughly the right moments."""
+ with tf.Session():
+ mu = tf.constant(3.0)
+ sigma = tf.constant(math.sqrt(1/0.1))
+ mu_v = 3.0
+ sigma_v = np.sqrt(1/0.1)
+ n = tf.constant(10000)
+ gaussian = tf.contrib.distributions.Gaussian(mu=mu, sigma=sigma)
+ # A fixed seed is passed so the moment checks below are repeatable.
+ samples = gaussian.sample(n, seed=137)
+ sample_values = samples.eval()
+ self.assertEqual(sample_values.shape, (10000,))
+ # Empirical mean and stddev are compared with loose absolute tolerances.
+ self.assertAllClose(sample_values.mean(), mu_v, atol=1e-2)
+ self.assertAllClose(sample_values.std(), sigma_v, atol=1e-1)
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/gaussian.py b/tensorflow/contrib/distributions/python/ops/gaussian.py
new file mode 100644
index 0000000..d54825c
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/gaussian.py
@@ -0,0 +1,123 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Normal (Gaussian) distribution class.
+
+@@Gaussian
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+
+
+class Gaussian(object):
+  """The Normal (Gaussian) distribution with mean mu and stddev sigma.
+
+  The PDF of this distribution is:
+    f(x) = sqrt(1/(2*pi*sigma^2)) exp(-(x-mu)^2/(2*sigma^2))
+  """
+
+  def __init__(self, mu, sigma):
+    """Construct a new Gaussian distribution with mean mu and stddev sigma.
+
+    Args:
+      mu: Scalar tensor (or value convertible to one), the mean of the
+        distribution.
+      sigma: Scalar tensor (or value convertible to one), the standard
+        deviation of the distribution.
+
+    Raises:
+      TypeError: if mu and sigma are different dtypes.
+    """
+    self._mu = ops.convert_to_tensor(mu)
+    self._sigma = ops.convert_to_tensor(sigma)
+    # Compare the converted tensors, not the raw arguments: plain Python
+    # numbers are accepted by convert_to_tensor but have no `dtype` attribute.
+    if self._mu.dtype != self._sigma.dtype:
+      raise TypeError("Expected same dtype for mu, sigma but got: %s vs. %s"
+                      % (self._mu.dtype, self._sigma.dtype))
+
+  @property
+  def dtype(self):
+    """The dtype of the `mu` tensor (checked to equal that of `sigma`)."""
+    return self._mu.dtype
+
+  @property
+  def shape(self):
+    """An empty constant tensor, marking the distribution as scalar."""
+    return constant_op.constant([])  # Scalar
+
+  @property
+  def mu(self):
+    """The mean of the distribution, as a tensor."""
+    return self._mu
+
+  @property
+  def sigma(self):
+    """The standard deviation of the distribution, as a tensor."""
+    return self._sigma
+
+  def log_pdf(self, x):
+    """Log likelihood of observations in x under Gaussian with mu and sigma.
+
+    Args:
+      x: 1-D, a vector of observations.
+
+    Returns:
+      log_lik: 1-D, a vector of log likelihoods of `x` under the model.
+    """
+    # log f(x) = -0.5*log(2*pi) - log(sigma) - 0.5*((x - mu)/sigma)^2
+    return (-0.5*math.log(2 * math.pi) - math_ops.log(self._sigma)
+            -0.5*math_ops.square((x - self._mu) / self._sigma))
+
+  def cdf(self, x):
+    """CDF of observations in x under Gaussian with mu and sigma.
+
+    Args:
+      x: 1-D, a vector of observations.
+
+    Returns:
+      cdf: 1-D, a vector of CDFs of `x` under the model.
+    """
+    # CDF(x) = 0.5 * (1 + erf((x - mu) / (sigma * sqrt(2))))
+    return (0.5 + 0.5*math_ops.erf(
+        1.0/(math.sqrt(2.0) * self._sigma)*(x - self._mu)))
+
+  def log_cdf(self, x):
+    """Log of the CDF of observations x under Gaussian with mu and sigma.
+
+    Args:
+      x: 1-D, a vector of observations.
+
+    Returns:
+      log_cdf: 1-D, the elementwise log of `cdf(x)`.
+    """
+    return math_ops.log(self.cdf(x))
+
+  def pdf(self, x):
+    """The PDF for observations x.
+
+    Args:
+      x: 1-D, a vector of observations.
+
+    Returns:
+      pdf: 1-D, a vector of pdf values of `x` under the model.
+    """
+    return math_ops.exp(self.log_pdf(x))
+
+  def sample(self, n, seed=None):
+    """Sample `n` observations from this Distribution.
+
+    Args:
+      n: Scalar int `Tensor`, the number of observations to sample.
+      seed: Python integer, the random seed.
+
+    Returns:
+      samples: A vector of samples with shape `[n]`.
+    """
+    # random_normal needs a 1-D shape, so the scalar `n` is expanded to `[n]`.
+    return random_ops.random_normal(
+        shape=array_ops.expand_dims(n, 0), mean=self._mu,
+        stddev=self._sigma, dtype=self._mu.dtype, seed=seed)
diff --git a/tensorflow/contrib/distributions/python/ops/gaussian_conjugate_posteriors.py b/tensorflow/contrib/distributions/python/ops/gaussian_conjugate_posteriors.py
new file mode 100644
index 0000000..cd59a09
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/gaussian_conjugate_posteriors.py
@@ -0,0 +1,126 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Gaussian distribution: conjugate posterior closed form calculations.
+
+@@known_sigma_posterior
+@@known_sigma_predictive
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.distributions.python.ops.gaussian import Gaussian # pylint: disable=line-too-long
+
+from tensorflow.python.ops import math_ops
+
+
+def known_sigma_posterior(prior, sigma, s, n):
+  """Posterior over the mean of a Gaussian whose stddev `sigma` is known.
+
+  Takes a Gaussian prior with parameters `(mu0, sigma0)`, the known stddev
+  `sigma` of the (also Gaussian) observation model, and the statistics `s`
+  (the sum of the observations) and `n` (the number of observations).
+
+  The result is again Gaussian, with parameters `(mu', sigma'^2)`, where:
+  ```
+  sigma'^2 = 1/(1/sigma0^2 + n/sigma^2),
+  mu' = (mu0/sigma0^2 + s/sigma^2) * sigma'^2.
+  ```
+
+  Args:
+    prior: `Gaussian` object of type `dtype`, the prior distribution having
+      parameters `(mu0, sigma0)`.
+    sigma: Scalar of type `dtype`, `sigma > 0`. The known stddev parameter.
+    s: Scalar, of type `dtype`, the sum of observations.
+    n: Scalar int, the number of observations.
+
+  Returns:
+    A new `Gaussian` posterior distribution.
+
+  Raises:
+    TypeError: if `prior` is not a `Gaussian` object, or the dtype of `s`
+      does not match the dtype of `prior`.
+  """
+  if not isinstance(prior, Gaussian):
+    raise TypeError("Expected prior to be an instance of type Gaussian")
+
+  dtype = prior.dtype
+  if s.dtype != dtype:
+    raise TypeError(
+        "Observation sum s.dtype does not match prior dtype: %s vs. %s"
+        % (s.dtype, dtype))
+
+  # `n` arrives as an int; cast it so it can be divided by float tensors.
+  count = math_ops.cast(n, dtype)
+  prior_var = math_ops.square(prior.sigma)
+  obs_var = math_ops.square(sigma)
+  posterior_var = 1.0/(1/prior_var + count/obs_var)
+  posterior_mu = (prior.mu/prior_var + s/obs_var) * posterior_var
+  posterior_sigma = math_ops.sqrt(posterior_var)
+  return Gaussian(mu=posterior_mu, sigma=posterior_sigma)
+
+
+def known_sigma_predictive(prior, sigma, s, n):
+  """Posterior predictive of a Gaussian whose stddev `sigma` is known.
+
+  Takes a Gaussian prior with parameters `(mu0, sigma0)`, the known stddev
+  `sigma` of the (also Gaussian) observation model, and the statistics `s`
+  (the sum of the observations) and `n` (the number of observations).
+
+  Before any data is seen, marginalizing the mean out of the model gives
+  ```
+  p(x | sigma) = int N(x | mu, sigma^2) N(mu | prior.mu, prior.sigma^2) dmu
+               = N(x | prior.mu, sigma^2 + prior.sigma^2)
+  ```
+
+  After observing the data, the same integral against the posterior yields
+  the returned predictive distribution, with parameters `(mu', sigma'^2)`:
+  ```
+  sigma_n^2 = 1/(1/sigma0^2 + n/sigma^2),
+  mu' = (mu0/sigma0^2 + s/sigma^2) * sigma_n^2.
+  sigma'^2 = sigma_n^2 + sigma^2,
+  ```
+
+  Args:
+    prior: `Gaussian` object of type `dtype`, the prior distribution having
+      parameters `(mu0, sigma0)`.
+    sigma: Scalar of type `dtype`, `sigma > 0`. The known stddev parameter.
+    s: Scalar, of type `dtype`, the sum of observations.
+    n: Scalar int, the number of observations.
+
+  Returns:
+    A new `Gaussian` posterior predictive distribution.
+
+  Raises:
+    TypeError: if `prior` is not a `Gaussian` object, or the dtype of `s`
+      does not match the dtype of `prior`.
+  """
+  if not isinstance(prior, Gaussian):
+    raise TypeError("Expected prior to be an instance of type Gaussian")
+
+  dtype = prior.dtype
+  if s.dtype != dtype:
+    raise TypeError(
+        "Observation sum s.dtype does not match prior dtype: %s vs. %s"
+        % (s.dtype, dtype))
+
+  # `n` arrives as an int; cast it so it can be divided by float tensors.
+  count = math_ops.cast(n, dtype)
+  prior_var = math_ops.square(prior.sigma)
+  obs_var = math_ops.square(sigma)
+  posterior_var = 1.0/(1/prior_var + count/obs_var)
+  predictive_mu = (prior.mu/prior_var + s/obs_var) * posterior_var
+  # Predictive variance = posterior variance of the mean + observation noise.
+  predictive_sigma = math_ops.sqrt(posterior_var + obs_var)
+  return Gaussian(mu=predictive_mu, sigma=predictive_sigma)