cifs: push rfc1002 generation down the stack

Move the generation of the 4 byte length field down the stack and
generate it immediately before we start writing the data to the socket.

Signed-off-by: Ronnie Sahlberg <lsahlber@redhat.com>
Signed-off-by: Aurelien Aptel <aaptel@suse.com>
Signed-off-by: Steve French <smfrench@gmail.com>
diff --git a/fs/cifs/cifsencrypt.c b/fs/cifs/cifsencrypt.c
index 937251c..f23ff84 100644
--- a/fs/cifs/cifsencrypt.c
+++ b/fs/cifs/cifsencrypt.c
@@ -37,7 +37,6 @@
 #include <crypto/aead.h>
 
 int __cifs_calc_signature(struct smb_rqst *rqst,
-			int start,
 			struct TCP_Server_Info *server, char *signature,
 			struct shash_desc *shash)
 {
@@ -45,16 +44,30 @@
 	int rc;
 	struct kvec *iov = rqst->rq_iov;
 	int n_vec = rqst->rq_nvec;
+	int is_smb2 = server->vals->header_preamble_size == 0;
 
-	for (i = start; i < n_vec; i++) {
+	/* iov[0] is actual data and not the rfc1002 length for SMB2+ */
+	if (is_smb2) {
+		rc = crypto_shash_update(shash,
+					 iov[0].iov_base, iov[0].iov_len);
+	} else {
+		if (n_vec < 2 || iov[0].iov_len != 4)
+			return -EIO;
+	}
+
+	for (i = 1; i < n_vec; i++) {
 		if (iov[i].iov_len == 0)
 			continue;
 		if (iov[i].iov_base == NULL) {
 			cifs_dbg(VFS, "null iovec entry\n");
 			return -EIO;
 		}
-		if (i == 1 && iov[1].iov_len <= 4)
-			break; /* nothing to sign or corrupt header */
+		if (is_smb2) {
+			if (i == 0 && iov[0].iov_len <= 4)
+				break; /* nothing to sign or corrupt header */
+		} else
+			if (i == 1 && iov[1].iov_len <= 4)
+				break; /* nothing to sign or corrupt header */
 		rc = crypto_shash_update(shash,
 					 iov[i].iov_base, iov[i].iov_len);
 		if (rc) {
@@ -118,7 +131,7 @@
 		return rc;
 	}
 
-	return __cifs_calc_signature(rqst, 1, server, signature,
+	return __cifs_calc_signature(rqst, server, signature,
 				     &server->secmech.sdescmd5->shash);
 }
 
diff --git a/fs/cifs/cifsproto.h b/fs/cifs/cifsproto.h
index 3a13b44..4f92182 100644
--- a/fs/cifs/cifsproto.h
+++ b/fs/cifs/cifsproto.h
@@ -544,7 +544,7 @@
 			   struct cifs_sb_info *cifs_sb,
 			   const unsigned char *path, char *pbuf,
 			   unsigned int *pbytes_written);
-int __cifs_calc_signature(struct smb_rqst *rqst, int start,
+int __cifs_calc_signature(struct smb_rqst *rqst,
 			struct TCP_Server_Info *server, char *signature,
 			struct shash_desc *shash);
 enum securityEnum cifs_select_sectype(struct TCP_Server_Info *,
diff --git a/fs/cifs/smb2ops.c b/fs/cifs/smb2ops.c
index 682bcfa..9153407 100644
--- a/fs/cifs/smb2ops.c
+++ b/fs/cifs/smb2ops.c
@@ -2167,7 +2167,7 @@
 		   struct smb_rqst *old_rq)
 {
 	struct smb2_sync_hdr *shdr =
-			(struct smb2_sync_hdr *)old_rq->rq_iov[1].iov_base;
+			(struct smb2_sync_hdr *)old_rq->rq_iov[0].iov_base;
 
 	memset(tr_hdr, 0, sizeof(struct smb2_transform_hdr));
 	tr_hdr->ProtocolId = SMB2_TRANSFORM_PROTO_NUM;
@@ -2187,14 +2187,13 @@
 }
 
 /* Assumes:
- * rqst->rq_iov[0]  is rfc1002 length
- * rqst->rq_iov[1]  is tranform header
- * rqst->rq_iov[2+] data to be encrypted/decrypted
+ * rqst->rq_iov[0]  is tranform header
+ * rqst->rq_iov[1+] data to be encrypted/decrypted
  */
 static struct scatterlist *
 init_sg(struct smb_rqst *rqst, u8 *sign)
 {
-	unsigned int sg_len = rqst->rq_nvec + rqst->rq_npages;
+	unsigned int sg_len = rqst->rq_nvec + rqst->rq_npages + 1;
 	unsigned int assoc_data_len = sizeof(struct smb2_transform_hdr) - 20;
 	struct scatterlist *sg;
 	unsigned int i;
@@ -2205,10 +2204,10 @@
 		return NULL;
 
 	sg_init_table(sg, sg_len);
-	smb2_sg_set_buf(&sg[0], rqst->rq_iov[1].iov_base + 20, assoc_data_len);
-	for (i = 1; i < rqst->rq_nvec - 1; i++)
-		smb2_sg_set_buf(&sg[i], rqst->rq_iov[i+1].iov_base,
-						rqst->rq_iov[i+1].iov_len);
+	smb2_sg_set_buf(&sg[0], rqst->rq_iov[0].iov_base + 20, assoc_data_len);
+	for (i = 1; i < rqst->rq_nvec; i++)
+		smb2_sg_set_buf(&sg[i], rqst->rq_iov[i].iov_base,
+						rqst->rq_iov[i].iov_len);
 	for (j = 0; i < sg_len - 1; i++, j++) {
 		unsigned int len, offset;
 
@@ -2240,11 +2239,10 @@
 	return 1;
 }
 /*
- * Encrypt or decrypt @rqst message. @rqst has the following format:
- * iov[0] - rfc1002 length
- * iov[1] - transform header (associate data),
- * iov[2-N] and pages - data to encrypt.
- * On success return encrypted data in iov[2-N] and pages, leave iov[0-1]
+ * Encrypt or decrypt @rqst message. @rqst[0] has the following format:
+ * iov[0]   - transform header (associate data),
+ * iov[1-N] - SMB2 header and pages - data to encrypt.
+ * On success return encrypted data in iov[1-N] and pages, leave iov[0]
  * untouched.
  */
 static int
@@ -2339,10 +2337,6 @@
 	return rc;
 }
 
-/*
- * This is called from smb_send_rqst. At this point we have the rfc1002
- * header as the first element in the vector.
- */
 static int
 smb3_init_transform_rq(struct TCP_Server_Info *server, struct smb_rqst *new_rq,
 		       struct smb_rqst *old_rq)
@@ -2351,7 +2345,7 @@
 	struct page **pages;
 	struct smb2_transform_hdr *tr_hdr;
 	unsigned int npages = old_rq->rq_npages;
-	unsigned int orig_len = get_rfc1002_length(old_rq->rq_iov[0].iov_base);
+	unsigned int orig_len = 0;
 	int i;
 	int rc = -ENOMEM;
 
@@ -2365,24 +2359,23 @@
 	new_rq->rq_pagesz = old_rq->rq_pagesz;
 	new_rq->rq_tailsz = old_rq->rq_tailsz;
 
+	for (i = 0; i < old_rq->rq_nvec; i++)
+		orig_len += old_rq->rq_iov[i].iov_len;
+
 	for (i = 0; i < npages; i++) {
 		pages[i] = alloc_page(GFP_KERNEL|__GFP_HIGHMEM);
 		if (!pages[i])
 			goto err_free_pages;
 	}
 
-	/* Make space for one extra iov to hold the transform header */
 	iov = kmalloc_array(old_rq->rq_nvec + 1, sizeof(struct kvec),
 			    GFP_KERNEL);
 	if (!iov)
 		goto err_free_pages;
 
-	/* copy all iovs from the old except the 1st one (rfc1002 length) */
-	memcpy(&iov[2], &old_rq->rq_iov[1],
-				sizeof(struct kvec) * (old_rq->rq_nvec - 1));
-	/* copy the rfc1002 iov */
-	iov[0].iov_base = old_rq->rq_iov[0].iov_base;
-	iov[0].iov_len  = old_rq->rq_iov[0].iov_len;
+	/* copy all iovs from the old */
+	memcpy(&iov[1], &old_rq->rq_iov[0],
+				sizeof(struct kvec) * old_rq->rq_nvec);
 
 	new_rq->rq_iov = iov;
 	new_rq->rq_nvec = old_rq->rq_nvec + 1;
@@ -2393,12 +2386,8 @@
 
 	/* fill the 2nd iov with a transform header */
 	fill_transform_hdr(tr_hdr, orig_len, old_rq);
-	new_rq->rq_iov[1].iov_base = tr_hdr;
-	new_rq->rq_iov[1].iov_len = sizeof(struct smb2_transform_hdr);
-
-	/* Update rfc1002 header */
-	inc_rfc1001_len(new_rq->rq_iov[0].iov_base,
-			sizeof(struct smb2_transform_hdr));
+	new_rq->rq_iov[0].iov_base = tr_hdr;
+	new_rq->rq_iov[0].iov_len = sizeof(struct smb2_transform_hdr);
 
 	/* copy pages form the old */
 	for (i = 0; i < npages; i++) {
@@ -2442,7 +2431,7 @@
 		put_page(rqst->rq_pages[i]);
 	kfree(rqst->rq_pages);
 	/* free transform header */
-	kfree(rqst->rq_iov[1].iov_base);
+	kfree(rqst->rq_iov[0].iov_base);
 	kfree(rqst->rq_iov);
 }
 
@@ -2459,19 +2448,17 @@
 		 unsigned int buf_data_size, struct page **pages,
 		 unsigned int npages, unsigned int page_data_size)
 {
-	struct kvec iov[3];
+	struct kvec iov[2];
 	struct smb_rqst rqst = {NULL};
 	int rc;
 
-	iov[0].iov_base = NULL;
-	iov[0].iov_len = 0;
-	iov[1].iov_base = buf;
-	iov[1].iov_len = sizeof(struct smb2_transform_hdr);
-	iov[2].iov_base = buf + sizeof(struct smb2_transform_hdr);
-	iov[2].iov_len = buf_data_size;
+	iov[0].iov_base = buf;
+	iov[0].iov_len = sizeof(struct smb2_transform_hdr);
+	iov[1].iov_base = buf + sizeof(struct smb2_transform_hdr);
+	iov[1].iov_len = buf_data_size;
 
 	rqst.rq_iov = iov;
-	rqst.rq_nvec = 3;
+	rqst.rq_nvec = 2;
 	rqst.rq_pages = pages;
 	rqst.rq_npages = npages;
 	rqst.rq_pagesz = PAGE_SIZE;
@@ -2483,7 +2470,7 @@
 	if (rc)
 		return rc;
 
-	memmove(buf, iov[2].iov_base, buf_data_size);
+	memmove(buf, iov[1].iov_base, buf_data_size);
 
 	server->total_read = buf_data_size + page_data_size;
 
diff --git a/fs/cifs/smb2pdu.c b/fs/cifs/smb2pdu.c
index 328e23a..c48608c 100644
--- a/fs/cifs/smb2pdu.c
+++ b/fs/cifs/smb2pdu.c
@@ -2595,11 +2595,10 @@
 {
 	struct smb2_echo_req *req;
 	int rc = 0;
-	struct kvec iov[2];
+	struct kvec iov[1];
 	struct smb_rqst rqst = { .rq_iov = iov,
-				 .rq_nvec = 2 };
+				 .rq_nvec = 1 };
 	unsigned int total_len;
-	__be32 rfc1002_marker;
 
 	cifs_dbg(FYI, "In echo request\n");
 
@@ -2615,11 +2614,8 @@
 
 	req->sync_hdr.CreditRequest = cpu_to_le16(1);
 
-	iov[0].iov_len = 4;
-	rfc1002_marker = cpu_to_be32(total_len);
-	iov[0].iov_base = &rfc1002_marker;
-	iov[1].iov_len = total_len;
-	iov[1].iov_base = (char *)req;
+	iov[0].iov_len = total_len;
+	iov[0].iov_base = (char *)req;
 
 	rc = cifs_call_async(server, &rqst, NULL, smb2_echo_callback, NULL,
 			     server, CIFS_ECHO_OP);
@@ -2849,10 +2845,9 @@
 	struct smb2_sync_hdr *shdr;
 	struct cifs_io_parms io_parms;
 	struct smb_rqst rqst = { .rq_iov = rdata->iov,
-				 .rq_nvec = 2 };
+				 .rq_nvec = 1 };
 	struct TCP_Server_Info *server;
 	unsigned int total_len;
-	__be32 req_len;
 
 	cifs_dbg(FYI, "%s: offset=%llu bytes=%u\n",
 		 __func__, rdata->offset, rdata->bytes);
@@ -2883,12 +2878,8 @@
 	if (smb3_encryption_required(io_parms.tcon))
 		flags |= CIFS_TRANSFORM_REQ;
 
-	req_len = cpu_to_be32(total_len);
-
-	rdata->iov[0].iov_base = &req_len;
-	rdata->iov[0].iov_len = sizeof(__be32);
-	rdata->iov[1].iov_base = buf;
-	rdata->iov[1].iov_len = total_len;
+	rdata->iov[0].iov_base = buf;
+	rdata->iov[0].iov_len = total_len;
 
 	shdr = (struct smb2_sync_hdr *)buf;
 
@@ -3063,10 +3054,9 @@
 	struct smb2_sync_hdr *shdr;
 	struct cifs_tcon *tcon = tlink_tcon(wdata->cfile->tlink);
 	struct TCP_Server_Info *server = tcon->ses->server;
-	struct kvec iov[2];
+	struct kvec iov[1];
 	struct smb_rqst rqst = { };
 	unsigned int total_len;
-	__be32 rfc1002_marker;
 
 	rc = smb2_plain_req_init(SMB2_WRITE, tcon, (void **) &req, &total_len);
 	if (rc) {
@@ -3138,15 +3128,11 @@
 		v1->length = cpu_to_le32(wdata->mr->mr->length);
 	}
 #endif
-	/* 4 for rfc1002 length field and 1 for Buffer */
-	iov[0].iov_len = 4;
-	rfc1002_marker = cpu_to_be32(total_len - 1 + wdata->bytes);
-	iov[0].iov_base = &rfc1002_marker;
-	iov[1].iov_len = total_len - 1;
-	iov[1].iov_base = (char *)req;
+	iov[0].iov_len = total_len - 1;
+	iov[0].iov_base = (char *)req;
 
 	rqst.rq_iov = iov;
-	rqst.rq_nvec = 2;
+	rqst.rq_nvec = 1;
 	rqst.rq_pages = wdata->pages;
 	rqst.rq_offset = wdata->page_offset;
 	rqst.rq_npages = wdata->nr_pages;
@@ -3154,7 +3140,7 @@
 	rqst.rq_tailsz = wdata->tailsz;
 #ifdef CONFIG_CIFS_SMB_DIRECT
 	if (wdata->mr) {
-		iov[1].iov_len += sizeof(struct smbd_buffer_descriptor_v1);
+		iov[0].iov_len += sizeof(struct smbd_buffer_descriptor_v1);
 		rqst.rq_npages = 0;
 	}
 #endif
diff --git a/fs/cifs/smb2transport.c b/fs/cifs/smb2transport.c
index 349d5cc..51b9437 100644
--- a/fs/cifs/smb2transport.c
+++ b/fs/cifs/smb2transport.c
@@ -171,9 +171,7 @@
 	unsigned char smb2_signature[SMB2_HMACSHA256_SIZE];
 	unsigned char *sigptr = smb2_signature;
 	struct kvec *iov = rqst->rq_iov;
-	int iov_hdr_index = rqst->rq_nvec > 1 ? 1 : 0;
-	struct smb2_sync_hdr *shdr =
-		(struct smb2_sync_hdr *)iov[iov_hdr_index].iov_base;
+	struct smb2_sync_hdr *shdr = (struct smb2_sync_hdr *)iov[0].iov_base;
 	struct cifs_ses *ses;
 
 	ses = smb2_find_smb_ses(server, shdr->SessionId);
@@ -204,7 +202,7 @@
 		return rc;
 	}
 
-	rc = __cifs_calc_signature(rqst, iov_hdr_index,  server, sigptr,
+	rc = __cifs_calc_signature(rqst, server, sigptr,
 		&server->secmech.sdeschmacsha256->shash);
 
 	if (!rc)
@@ -414,9 +412,7 @@
 	unsigned char smb3_signature[SMB2_CMACAES_SIZE];
 	unsigned char *sigptr = smb3_signature;
 	struct kvec *iov = rqst->rq_iov;
-	int iov_hdr_index = rqst->rq_nvec > 1 ? 1 : 0;
-	struct smb2_sync_hdr *shdr =
-		(struct smb2_sync_hdr *)iov[iov_hdr_index].iov_base;
+	struct smb2_sync_hdr *shdr = (struct smb2_sync_hdr *)iov[0].iov_base;
 	struct cifs_ses *ses;
 
 	ses = smb2_find_smb_ses(server, shdr->SessionId);
@@ -447,7 +443,7 @@
 		return rc;
 	}
 
-	rc = __cifs_calc_signature(rqst, iov_hdr_index, server, sigptr,
+	rc = __cifs_calc_signature(rqst, server, sigptr,
 				   &server->secmech.sdesccmacaes->shash);
 
 	if (!rc)
@@ -462,7 +458,7 @@
 {
 	int rc = 0;
 	struct smb2_sync_hdr *shdr =
-			(struct smb2_sync_hdr *)rqst->rq_iov[1].iov_base;
+			(struct smb2_sync_hdr *)rqst->rq_iov[0].iov_base;
 
 	if (!(shdr->Flags & SMB2_FLAGS_SIGNED) ||
 	    server->tcpStatus == CifsNeedNegotiate)
@@ -635,7 +631,7 @@
 {
 	int rc;
 	struct smb2_sync_hdr *shdr =
-			(struct smb2_sync_hdr *)rqst->rq_iov[1].iov_base;
+			(struct smb2_sync_hdr *)rqst->rq_iov[0].iov_base;
 	struct mid_q_entry *mid;
 
 	smb2_seq_num_into_buf(ses->server, shdr);
@@ -656,7 +652,7 @@
 {
 	int rc;
 	struct smb2_sync_hdr *shdr =
-			(struct smb2_sync_hdr *)rqst->rq_iov[1].iov_base;
+			(struct smb2_sync_hdr *)rqst->rq_iov[0].iov_base;
 	struct mid_q_entry *mid;
 
 	smb2_seq_num_into_buf(server, shdr);
diff --git a/fs/cifs/transport.c b/fs/cifs/transport.c
index 1f1a68f..63f25f9 100644
--- a/fs/cifs/transport.c
+++ b/fs/cifs/transport.c
@@ -241,13 +241,14 @@
 	int rc;
 	struct kvec *iov = rqst->rq_iov;
 	int n_vec = rqst->rq_nvec;
-	unsigned int smb_buf_length = get_rfc1002_length(iov[0].iov_base);
-	unsigned long send_length;
+	unsigned int send_length;
 	unsigned int i;
 	size_t total_len = 0, sent, size;
 	struct socket *ssocket = server->ssocket;
 	struct msghdr smb_msg;
 	int val = 1;
+	__be32 rfc1002_marker;
+
 	if (cifs_rdma_enabled(server) && server->smbd_conn) {
 		rc = smbd_send(server->smbd_conn, rqst);
 		goto smbd_done;
@@ -255,26 +256,34 @@
 	if (ssocket == NULL)
 		return -ENOTSOCK;
 
-	/* sanity check send length */
 	send_length = rqst_len(rqst);
-	if (send_length != smb_buf_length + 4) {
-		WARN(1, "Send length mismatch(send_length=%lu smb_buf_length=%u)\n",
-			send_length, smb_buf_length);
-		return -EIO;
-	}
-
-	if (n_vec < 2)
-		return -EIO;
-
-	cifs_dbg(FYI, "Sending smb: smb_len=%u\n", smb_buf_length);
-	dump_smb(iov[0].iov_base, iov[0].iov_len);
-	dump_smb(iov[1].iov_base, iov[1].iov_len);
+	rfc1002_marker = cpu_to_be32(send_length);
 
 	/* cork the socket */
 	kernel_setsockopt(ssocket, SOL_TCP, TCP_CORK,
 				(char *)&val, sizeof(val));
 
 	size = 0;
+	/* Generate a rfc1002 marker for SMB2+ */
+	if (server->vals->header_preamble_size == 0) {
+		struct kvec hiov = {
+			.iov_base = &rfc1002_marker,
+			.iov_len  = 4
+		};
+		iov_iter_kvec(&smb_msg.msg_iter, WRITE | ITER_KVEC, &hiov,
+			      1, 4);
+		rc = smb_send_kvec(server, &smb_msg, &sent);
+		if (rc < 0)
+			goto uncork;
+
+		total_len += sent;
+		send_length += 4;
+	}
+
+	cifs_dbg(FYI, "Sending smb: smb_len=%u\n", send_length);
+	dump_smb(iov[0].iov_base, iov[0].iov_len);
+	dump_smb(iov[1].iov_base, iov[1].iov_len);
+
 	for (i = 0; i < n_vec; i++)
 		size += iov[i].iov_len;
 
@@ -308,9 +317,9 @@
 	kernel_setsockopt(ssocket, SOL_TCP, TCP_CORK,
 				(char *)&val, sizeof(val));
 
-	if ((total_len > 0) && (total_len != smb_buf_length + 4)) {
+	if ((total_len > 0) && (total_len != send_length)) {
 		cifs_dbg(FYI, "partial send (wanted=%u sent=%zu): terminating session\n",
-			 smb_buf_length + 4, total_len);
+			 send_length, total_len);
 		/*
 		 * If we have only sent part of an SMB then the next SMB could
 		 * be taken as the remainder of this one. We need to kill the
@@ -730,7 +739,6 @@
 	 * to the same server. We may make this configurable later or
 	 * use ses->maxReq.
 	 */
-
 	rc = wait_for_free_request(ses->server, timeout, optype);
 	if (rc)
 		return rc;
@@ -766,8 +774,8 @@
 
 #ifdef CONFIG_CIFS_SMB311
 	if ((ses->status == CifsNew) || (optype & CIFS_NEG_OP))
-		smb311_update_preauth_hash(ses, rqst->rq_iov+1,
-					   rqst->rq_nvec-1);
+		smb311_update_preauth_hash(ses, rqst->rq_iov,
+					   rqst->rq_nvec);
 #endif
 
 	if (timeout == CIFS_ASYNC_OP)
@@ -812,8 +820,8 @@
 #ifdef CONFIG_CIFS_SMB311
 	if ((ses->status == CifsNew) || (optype & CIFS_NEG_OP)) {
 		struct kvec iov = {
-			.iov_base = buf,
-			.iov_len = midQ->resp_buf_size
+			.iov_base = resp_iov->iov_base,
+			.iov_len = resp_iov->iov_len
 		};
 		smb311_update_preauth_hash(ses, &iov, 1);
 	}
@@ -879,39 +887,13 @@
 	       const int flags, struct kvec *resp_iov)
 {
 	struct smb_rqst rqst;
-	struct kvec s_iov[CIFS_MAX_IOV_SIZE], *new_iov;
 	int rc;
-	int i;
-	__u32 count;
-	__be32 rfc1002_marker;
-
-	if (n_vec + 1 > CIFS_MAX_IOV_SIZE) {
-		new_iov = kmalloc_array(n_vec + 1, sizeof(struct kvec),
-					GFP_KERNEL);
-		if (!new_iov)
-			return -ENOMEM;
-	} else
-		new_iov = s_iov;
-
-	/* 1st iov is an RFC1002 Session Message length */
-	memcpy(new_iov + 1, iov, (sizeof(struct kvec) * n_vec));
-
-	count = 0;
-	for (i = 1; i < n_vec + 1; i++)
-		count += new_iov[i].iov_len;
-
-	rfc1002_marker = cpu_to_be32(count);
-
-	new_iov[0].iov_base = &rfc1002_marker;
-	new_iov[0].iov_len = 4;
 
 	memset(&rqst, 0, sizeof(struct smb_rqst));
-	rqst.rq_iov = new_iov;
-	rqst.rq_nvec = n_vec + 1;
+	rqst.rq_iov = iov;
+	rqst.rq_nvec = n_vec;
 
 	rc = cifs_send_recv(xid, ses, &rqst, resp_buf_type, flags, resp_iov);
-	if (n_vec + 1 > CIFS_MAX_IOV_SIZE)
-		kfree(new_iov);
 	return rc;
 }