sctp: fix the issue that flags are ignored when using kernel_connect

commit 644fbdeacf1d3edd366e44b8ba214de9d1dd66a9 upstream.

Now sctp uses inet_dgram_connect as its proto_ops .connect, and the flags
param can't be passed into its proto .connect, where the flags are really
needed.

sctp works around it by getting the flags from the socket file in
__sctp_connect. This works for connecting from userspace, as the user sock
inherently has a socket file and it passes f_flags as the flags param into
the proto_ops .connect.

However, the sock created by sock_create_kern doesn't have a socket file,
and it passes the flags (like O_NONBLOCK) by using the flags param in
kernel_connect, which calls proto_ops .connect later.

So to fix it, this patch defines a new proto_ops .connect for sctp,
sctp_inet_connect, which calls __sctp_connect() directly with this
flags param. After this, the sctp's proto .connect can be removed.

Note that sctp_inet_connect doesn't need to do some checks that are not
needed for sctp, which makes things better than with inet_dgram_connect.

Suggested-by: Marcelo Ricardo Leitner <marcelo.leitner@gmail.com>
Signed-off-by: Xin Long <lucien.xin@gmail.com>
Acked-by: Neil Horman <nhorman@tuxdriver.com>
Acked-by: Marcelo Ricardo Leitner <marcelo.leitner@gmail.com>
Reviewed-by: Michal Kubecek <mkubecek@suse.cz>
Signed-off-by: David S. Miller <davem@davemloft.net>
Signed-off-by: Greg Kroah-Hartman <gregkh@linuxfoundation.org>

diff --git a/net/sctp/ipv6.c b/net/sctp/ipv6.c
index 31f461f..824ebbf 100644
--- a/net/sctp/ipv6.c
+++ b/net/sctp/ipv6.c
@@ -973,7 +973,7 @@
 	.owner		   = THIS_MODULE,
 	.release	   = inet6_release,
 	.bind		   = inet6_bind,
-	.connect	   = inet_dgram_connect,
+	.connect	   = sctp_inet_connect,
 	.socketpair	   = sock_no_socketpair,
 	.accept		   = inet_accept,
 	.getname	   = sctp_getname,
diff --git a/net/sctp/protocol.c b/net/sctp/protocol.c
index 833283c..9cb06ca 100644
--- a/net/sctp/protocol.c
+++ b/net/sctp/protocol.c
@@ -1014,7 +1014,7 @@
 	.owner		   = THIS_MODULE,
 	.release	   = inet_release,	/* Needs to be wrapped... */
 	.bind		   = inet_bind,
-	.connect	   = inet_dgram_connect,
+	.connect	   = sctp_inet_connect,
 	.socketpair	   = sock_no_socketpair,
 	.accept		   = inet_accept,
 	.getname	   = inet_getname,	/* Semantics are different.  */
diff --git a/net/sctp/socket.c b/net/sctp/socket.c
index 574a6a2..5794ca5 100644
--- a/net/sctp/socket.c
+++ b/net/sctp/socket.c
@@ -1074,7 +1074,7 @@
  */
 static int __sctp_connect(struct sock *sk,
 			  struct sockaddr *kaddrs,
-			  int addrs_size,
+			  int addrs_size, int flags,
 			  sctp_assoc_t *assoc_id)
 {
 	struct net *net = sock_net(sk);
@@ -1092,7 +1092,6 @@
 	union sctp_addr *sa_addr = NULL;
 	void *addr_buf;
 	unsigned short port;
-	unsigned int f_flags = 0;
 
 	sp = sctp_sk(sk);
 	ep = sp->ep;
@@ -1240,13 +1239,7 @@
 	sp->pf->to_sk_daddr(sa_addr, sk);
 	sk->sk_err = 0;
 
-	/* in-kernel sockets don't generally have a file allocated to them
-	 * if all they do is call sock_create_kern().
-	 */
-	if (sk->sk_socket->file)
-		f_flags = sk->sk_socket->file->f_flags;
-
-	timeo = sock_sndtimeo(sk, f_flags & O_NONBLOCK);
+	timeo = sock_sndtimeo(sk, flags & O_NONBLOCK);
 
 	if (assoc_id)
 		*assoc_id = asoc->assoc_id;
@@ -1341,7 +1334,7 @@
 {
 	struct sockaddr *kaddrs;
 	gfp_t gfp = GFP_KERNEL;
-	int err = 0;
+	int err = 0, flags = 0;
 
 	pr_debug("%s: sk:%p addrs:%p addrs_size:%d\n",
 		 __func__, sk, addrs, addrs_size);
@@ -1361,11 +1354,18 @@
 		return -ENOMEM;
 
 	if (__copy_from_user(kaddrs, addrs, addrs_size)) {
-		err = -EFAULT;
-	} else {
-		err = __sctp_connect(sk, kaddrs, addrs_size, assoc_id);
+		kfree(kaddrs);
+		return -EFAULT;
 	}
 
+	/* in-kernel sockets don't generally have a file allocated to them
+	 * if all they do is call sock_create_kern().
+	 */
+	if (sk->sk_socket->file)
+		flags = sk->sk_socket->file->f_flags;
+
+	err = __sctp_connect(sk, kaddrs, addrs_size, flags, assoc_id);
+
 	kfree(kaddrs);
 
 	return err;
@@ -3979,16 +3979,26 @@
  * len: the size of the address.
  */
 static int sctp_connect(struct sock *sk, struct sockaddr *addr,
-			int addr_len)
+			int addr_len, int flags)
 {
-	int err = 0;
+	struct inet_sock *inet = inet_sk(sk);
 	struct sctp_af *af;
+	int err = 0;
 
 	lock_sock(sk);
 
 	pr_debug("%s: sk:%p, sockaddr:%p, addr_len:%d\n", __func__, sk,
 		 addr, addr_len);
 
+	/* We may need to bind the socket. */
+	if (!inet->inet_num) {
+		if (sk->sk_prot->get_port(sk, 0)) {
+			release_sock(sk);
+			return -EAGAIN;
+		}
+		inet->inet_sport = htons(inet->inet_num);
+	}
+
 	/* Validate addr_len before calling common connect/connectx routine. */
 	af = sctp_get_af_specific(addr->sa_family);
 	if (!af || addr_len < af->sockaddr_len) {
@@ -3997,13 +4007,25 @@
 		/* Pass correct addr len to common routine (so it knows there
 		 * is only one address being passed.
 		 */
-		err = __sctp_connect(sk, addr, af->sockaddr_len, NULL);
+		err = __sctp_connect(sk, addr, af->sockaddr_len, flags, NULL);
 	}
 
 	release_sock(sk);
 	return err;
 }
 
+int sctp_inet_connect(struct socket *sock, struct sockaddr *uaddr,
+		      int addr_len, int flags)
+{
+	if (addr_len < sizeof(uaddr->sa_family))
+		return -EINVAL;
+
+	if (uaddr->sa_family == AF_UNSPEC)
+		return -EOPNOTSUPP;
+
+	return sctp_connect(sock->sk, uaddr, addr_len, flags);
+}
+
 /* FIXME: Write comments. */
 static int sctp_disconnect(struct sock *sk, int flags)
 {
@@ -7896,7 +7918,6 @@
 	.name        =	"SCTP",
 	.owner       =	THIS_MODULE,
 	.close       =	sctp_close,
-	.connect     =	sctp_connect,
 	.disconnect  =	sctp_disconnect,
 	.accept      =	sctp_accept,
 	.ioctl       =	sctp_ioctl,
@@ -7935,7 +7956,6 @@
 	.name		= "SCTPv6",
 	.owner		= THIS_MODULE,
 	.close		= sctp_close,
-	.connect	= sctp_connect,
 	.disconnect	= sctp_disconnect,
 	.accept		= sctp_accept,
 	.ioctl		= sctp_ioctl,