shmctl: split the work from copyin/copyout

Signed-off-by: Al Viro <viro@zeniv.linux.org.uk>
diff --git a/ipc/shm.c b/ipc/shm.c
index 28a4448..b4073c0 100644
--- a/ipc/shm.c
+++ b/ipc/shm.c
@@ -813,23 +813,17 @@
  * NOTE: no locks must be held, the rwsem is taken inside this function.
  */
 static int shmctl_down(struct ipc_namespace *ns, int shmid, int cmd,
-		       struct shmid_ds __user *buf, int version)
+		       struct shmid64_ds *shmid64)
 {
 	struct kern_ipc_perm *ipcp;
-	struct shmid64_ds shmid64;
 	struct shmid_kernel *shp;
 	int err;
 
-	if (cmd == IPC_SET) {
-		if (copy_shmid_from_user(&shmid64, buf, version))
-			return -EFAULT;
-	}
-
 	down_write(&shm_ids(ns).rwsem);
 	rcu_read_lock();
 
 	ipcp = ipcctl_pre_down_nolock(ns, &shm_ids(ns), shmid, cmd,
-				      &shmid64.shm_perm, 0);
+				      &shmid64->shm_perm, 0);
 	if (IS_ERR(ipcp)) {
 		err = PTR_ERR(ipcp);
 		goto out_unlock1;
@@ -849,7 +843,7 @@
 		goto out_up;
 	case IPC_SET:
 		ipc_lock_object(&shp->shm_perm);
-		err = ipc_update_perm(&shmid64.shm_perm, ipcp);
+		err = ipc_update_perm(&shmid64->shm_perm, ipcp);
 		if (err)
 			goto out_unlock0;
 		shp->shm_ctim = get_seconds();
@@ -868,125 +862,175 @@
 	return err;
 }
 
-static int shmctl_nolock(struct ipc_namespace *ns, int shmid,
-			 int cmd, int version, void __user *buf)
+static int shmctl_ipc_info(struct ipc_namespace *ns,
+			   struct shminfo64 *shminfo)
 {
-	int err;
-	struct shmid_kernel *shp;
-
-	/* preliminary security checks for *_INFO */
-	if (cmd == IPC_INFO || cmd == SHM_INFO) {
-		err = security_shm_shmctl(NULL, cmd);
-		if (err)
-			return err;
-	}
-
-	switch (cmd) {
-	case IPC_INFO:
-	{
-		struct shminfo64 shminfo;
-
-		memset(&shminfo, 0, sizeof(shminfo));
-		shminfo.shmmni = shminfo.shmseg = ns->shm_ctlmni;
-		shminfo.shmmax = ns->shm_ctlmax;
-		shminfo.shmall = ns->shm_ctlall;
-
-		shminfo.shmmin = SHMMIN;
-		if (copy_shminfo_to_user(buf, &shminfo, version))
-			return -EFAULT;
-
+	int err = security_shm_shmctl(NULL, IPC_INFO);
+	if (!err) {
+		memset(shminfo, 0, sizeof(*shminfo));
+		shminfo->shmmni = shminfo->shmseg = ns->shm_ctlmni;
+		shminfo->shmmax = ns->shm_ctlmax;
+		shminfo->shmall = ns->shm_ctlall;
+		shminfo->shmmin = SHMMIN;
 		down_read(&shm_ids(ns).rwsem);
 		err = ipc_get_maxid(&shm_ids(ns));
 		up_read(&shm_ids(ns).rwsem);
-
 		if (err < 0)
 			err = 0;
-		goto out;
 	}
-	case SHM_INFO:
-	{
-		struct shm_info shm_info;
+	return err;
+}
 
-		memset(&shm_info, 0, sizeof(shm_info));
+static int shmctl_shm_info(struct ipc_namespace *ns,
+			   struct shm_info *shm_info)
+{
+	int err = security_shm_shmctl(NULL, SHM_INFO);
+	if (!err) {
+		memset(shm_info, 0, sizeof(*shm_info));
 		down_read(&shm_ids(ns).rwsem);
-		shm_info.used_ids = shm_ids(ns).in_use;
-		shm_get_stat(ns, &shm_info.shm_rss, &shm_info.shm_swp);
-		shm_info.shm_tot = ns->shm_tot;
-		shm_info.swap_attempts = 0;
-		shm_info.swap_successes = 0;
+		shm_info->used_ids = shm_ids(ns).in_use;
+		shm_get_stat(ns, &shm_info->shm_rss, &shm_info->shm_swp);
+		shm_info->shm_tot = ns->shm_tot;
+		shm_info->swap_attempts = 0;
+		shm_info->swap_successes = 0;
 		err = ipc_get_maxid(&shm_ids(ns));
 		up_read(&shm_ids(ns).rwsem);
-		if (copy_to_user(buf, &shm_info, sizeof(shm_info))) {
-			err = -EFAULT;
-			goto out;
-		}
-
-		err = err < 0 ? 0 : err;
-		goto out;
+		if (err < 0)
+			err = 0;
 	}
-	case SHM_STAT:
-	case IPC_STAT:
-	{
-		struct shmid64_ds tbuf;
-		int result;
+	return err;
+}
 
-		rcu_read_lock();
-		if (cmd == SHM_STAT) {
-			shp = shm_obtain_object(ns, shmid);
-			if (IS_ERR(shp)) {
-				err = PTR_ERR(shp);
-				goto out_unlock;
-			}
-			result = shp->shm_perm.id;
-		} else {
-			shp = shm_obtain_object_check(ns, shmid);
-			if (IS_ERR(shp)) {
-				err = PTR_ERR(shp);
-				goto out_unlock;
-			}
-			result = 0;
-		}
+static int shmctl_stat(struct ipc_namespace *ns, int shmid,
+			int cmd, struct shmid64_ds *tbuf)
+{
+	struct shmid_kernel *shp;
+	int result;
+	int err;
 
-		err = -EACCES;
-		if (ipcperms(ns, &shp->shm_perm, S_IRUGO))
+	rcu_read_lock();
+	if (cmd == SHM_STAT) {
+		shp = shm_obtain_object(ns, shmid);
+		if (IS_ERR(shp)) {
+			err = PTR_ERR(shp);
 			goto out_unlock;
-
-		err = security_shm_shmctl(shp, cmd);
-		if (err)
+		}
+		result = shp->shm_perm.id;
+	} else {
+		shp = shm_obtain_object_check(ns, shmid);
+		if (IS_ERR(shp)) {
+			err = PTR_ERR(shp);
 			goto out_unlock;
-
-		memset(&tbuf, 0, sizeof(tbuf));
-		kernel_to_ipc64_perm(&shp->shm_perm, &tbuf.shm_perm);
-		tbuf.shm_segsz	= shp->shm_segsz;
-		tbuf.shm_atime	= shp->shm_atim;
-		tbuf.shm_dtime	= shp->shm_dtim;
-		tbuf.shm_ctime	= shp->shm_ctim;
-		tbuf.shm_cpid	= shp->shm_cprid;
-		tbuf.shm_lpid	= shp->shm_lprid;
-		tbuf.shm_nattch	= shp->shm_nattch;
-		rcu_read_unlock();
-
-		if (copy_shmid_to_user(buf, &tbuf, version))
-			err = -EFAULT;
-		else
-			err = result;
-		goto out;
+		}
+		result = 0;
 	}
-	default:
-		return -EINVAL;
-	}
+
+	err = -EACCES;
+	if (ipcperms(ns, &shp->shm_perm, S_IRUGO))
+		goto out_unlock;
+
+	err = security_shm_shmctl(shp, cmd);
+	if (err)
+		goto out_unlock;
+
+	memset(tbuf, 0, sizeof(*tbuf));
+	kernel_to_ipc64_perm(&shp->shm_perm, &tbuf->shm_perm);
+	tbuf->shm_segsz	= shp->shm_segsz;
+	tbuf->shm_atime	= shp->shm_atim;
+	tbuf->shm_dtime	= shp->shm_dtim;
+	tbuf->shm_ctime	= shp->shm_ctim;
+	tbuf->shm_cpid	= shp->shm_cprid;
+	tbuf->shm_lpid	= shp->shm_lprid;
+	tbuf->shm_nattch = shp->shm_nattch;
+	rcu_read_unlock();
+	return result;
 
 out_unlock:
 	rcu_read_unlock();
-out:
+	return err;
+}
+
+static int shmctl_do_lock(struct ipc_namespace *ns, int shmid, int cmd)
+{
+	struct shmid_kernel *shp;
+	struct file *shm_file;
+	int err;
+
+	rcu_read_lock();
+	shp = shm_obtain_object_check(ns, shmid);
+	if (IS_ERR(shp)) {
+		err = PTR_ERR(shp);
+		goto out_unlock1;
+	}
+
+	audit_ipc_obj(&(shp->shm_perm));
+	err = security_shm_shmctl(shp, cmd);
+	if (err)
+		goto out_unlock1;
+
+	ipc_lock_object(&shp->shm_perm);
+
+	/* check if shm_destroy() is tearing down shp */
+	if (!ipc_valid_object(&shp->shm_perm)) {
+		err = -EIDRM;
+		goto out_unlock0;
+	}
+
+	if (!ns_capable(ns->user_ns, CAP_IPC_LOCK)) {
+		kuid_t euid = current_euid();
+
+		if (!uid_eq(euid, shp->shm_perm.uid) &&
+		    !uid_eq(euid, shp->shm_perm.cuid)) {
+			err = -EPERM;
+			goto out_unlock0;
+		}
+		if (cmd == SHM_LOCK && !rlimit(RLIMIT_MEMLOCK)) {
+			err = -EPERM;
+			goto out_unlock0;
+		}
+	}
+
+	shm_file = shp->shm_file;
+	if (is_file_hugepages(shm_file))
+		goto out_unlock0;
+
+	if (cmd == SHM_LOCK) {
+		struct user_struct *user = current_user();
+
+		err = shmem_lock(shm_file, 1, user);
+		if (!err && !(shp->shm_perm.mode & SHM_LOCKED)) {
+			shp->shm_perm.mode |= SHM_LOCKED;
+			shp->mlock_user = user;
+		}
+		goto out_unlock0;
+	}
+
+	/* SHM_UNLOCK */
+	if (!(shp->shm_perm.mode & SHM_LOCKED))
+		goto out_unlock0;
+	shmem_lock(shm_file, 0, shp->mlock_user);
+	shp->shm_perm.mode &= ~SHM_LOCKED;
+	shp->mlock_user = NULL;
+	get_file(shm_file);
+	ipc_unlock_object(&shp->shm_perm);
+	rcu_read_unlock();
+	shmem_unlock_mapping(shm_file->f_mapping);
+
+	fput(shm_file);
+	return err;
+
+out_unlock0:
+	ipc_unlock_object(&shp->shm_perm);
+out_unlock1:
+	rcu_read_unlock();
 	return err;
 }
 
 SYSCALL_DEFINE3(shmctl, int, shmid, int, cmd, struct shmid_ds __user *, buf)
 {
-	struct shmid_kernel *shp;
 	int err, version;
 	struct ipc_namespace *ns;
+	struct shmid64_ds tbuf;
 
 	if (cmd < 0 || shmid < 0)
 		return -EINVAL;
@@ -995,91 +1039,44 @@
 	ns = current->nsproxy->ipc_ns;
 
 	switch (cmd) {
-	case IPC_INFO:
-	case SHM_INFO:
-	case SHM_STAT:
-	case IPC_STAT:
-		return shmctl_nolock(ns, shmid, cmd, version, buf);
-	case IPC_RMID:
-	case IPC_SET:
-		return shmctl_down(ns, shmid, cmd, buf, version);
-	case SHM_LOCK:
-	case SHM_UNLOCK:
-	{
-		struct file *shm_file;
-
-		rcu_read_lock();
-		shp = shm_obtain_object_check(ns, shmid);
-		if (IS_ERR(shp)) {
-			err = PTR_ERR(shp);
-			goto out_unlock1;
-		}
-
-		audit_ipc_obj(&(shp->shm_perm));
-		err = security_shm_shmctl(shp, cmd);
-		if (err)
-			goto out_unlock1;
-
-		ipc_lock_object(&shp->shm_perm);
-
-		/* check if shm_destroy() is tearing down shp */
-		if (!ipc_valid_object(&shp->shm_perm)) {
-			err = -EIDRM;
-			goto out_unlock0;
-		}
-
-		if (!ns_capable(ns->user_ns, CAP_IPC_LOCK)) {
-			kuid_t euid = current_euid();
-
-			if (!uid_eq(euid, shp->shm_perm.uid) &&
-			    !uid_eq(euid, shp->shm_perm.cuid)) {
-				err = -EPERM;
-				goto out_unlock0;
-			}
-			if (cmd == SHM_LOCK && !rlimit(RLIMIT_MEMLOCK)) {
-				err = -EPERM;
-				goto out_unlock0;
-			}
-		}
-
-		shm_file = shp->shm_file;
-		if (is_file_hugepages(shm_file))
-			goto out_unlock0;
-
-		if (cmd == SHM_LOCK) {
-			struct user_struct *user = current_user();
-
-			err = shmem_lock(shm_file, 1, user);
-			if (!err && !(shp->shm_perm.mode & SHM_LOCKED)) {
-				shp->shm_perm.mode |= SHM_LOCKED;
-				shp->mlock_user = user;
-			}
-			goto out_unlock0;
-		}
-
-		/* SHM_UNLOCK */
-		if (!(shp->shm_perm.mode & SHM_LOCKED))
-			goto out_unlock0;
-		shmem_lock(shm_file, 0, shp->mlock_user);
-		shp->shm_perm.mode &= ~SHM_LOCKED;
-		shp->mlock_user = NULL;
-		get_file(shm_file);
-		ipc_unlock_object(&shp->shm_perm);
-		rcu_read_unlock();
-		shmem_unlock_mapping(shm_file->f_mapping);
-
-		fput(shm_file);
+	case IPC_INFO: {
+		struct shminfo64 shminfo;
+		err = shmctl_ipc_info(ns, &shminfo);
+		if (err < 0)
+			return err;
+		if (copy_shminfo_to_user(buf, &shminfo, version))
+			err = -EFAULT;
 		return err;
 	}
+	case SHM_INFO: {
+		struct shm_info shm_info;
+		err = shmctl_shm_info(ns, &shm_info);
+		if (err < 0)
+			return err;
+		if (copy_to_user(buf, &shm_info, sizeof(shm_info)))
+			err = -EFAULT;
+		return err;
+	}
+	case SHM_STAT:
+	case IPC_STAT: {
+		err = shmctl_stat(ns, shmid, cmd, &tbuf);
+		if (err < 0)
+			return err;
+		if (copy_shmid_to_user(buf, &tbuf, version))
+			err = -EFAULT;
+		return err;
+	}
+	case IPC_SET:
+		if (copy_shmid_from_user(&tbuf, buf, version))
+			return -EFAULT;
+	case IPC_RMID:
+		return shmctl_down(ns, shmid, cmd, &tbuf);
+	case SHM_LOCK:
+	case SHM_UNLOCK:
+		return shmctl_do_lock(ns, shmid, cmd);
 	default:
 		return -EINVAL;
 	}
-
-out_unlock0:
-	ipc_unlock_object(&shp->shm_perm);
-out_unlock1:
-	rcu_read_unlock();
-	return err;
 }
 
 /*