#8739: upgrade smtpd to RFC 5321 and 1870.

smtpd now handles EHLO and has infrastructure for extended smtp command mode.
The SIZE extension is also implemented.  In order to support parameters on
MAIL FROM, the RFC 5322 parser from the email package is used to parse the
address "token".

The logging tests subclass smtpd classes and override __init__, so it was
necessary to update those __init__ methods for the new signatures (which now
accept a data_size_limit argument) to make the logging tests pass.

The original suggestion and patch were by Alberto Trevino.  Juhana Jauhiainen
added the --size argument and SIZE parameter support.  Michele Orrù improved
the patch and added more tests.  Dan Boswell conditionalized various bits of
code on whether or not we are in HELO or EHLO mode, as well as some other
improvements and tests.  I finalized the patch and added the address parsing.
diff --git a/Lib/smtpd.py b/Lib/smtpd.py
index 748fcae..778d6d6 100755
--- a/Lib/smtpd.py
+++ b/Lib/smtpd.py
@@ -1,5 +1,5 @@
 #! /usr/bin/env python3
-"""An RFC 2821 smtp proxy.
+"""An RFC 5321 smtp proxy.
 Usage: %(program)s [options] [localhost:localport [remotehost:remoteport]]
@@ -20,6 +20,11 @@
         Use `classname' as the concrete SMTP proxy class.  Uses `PureProxy' by
+    --size limit
+    -s limit
+        Restrict the total size of the incoming message to "limit" number of
+        bytes via the RFC 1870 SIZE extension.  Defaults to 33554432 bytes.
         Turn on debugging prints.
@@ -35,10 +40,9 @@
 and if remoteport is not given, then 25 is used.
 # Overview:
-# This file implements the minimal SMTP protocol as defined in RFC 821.  It
+# This file implements the minimal SMTP protocol as defined in RFC 5321.  It
 # has a hierarchy of classes which implement the backend functionality for the
 # smtpd.  A number of classes are provided:
@@ -66,7 +70,7 @@
 # - support mailbox delivery
 # - alias files
-# - ESMTP
+# - Handle more ESMTP extensions
 # - handle error codes from the backend smtpd
 import sys
@@ -77,12 +81,14 @@
 import socket
 import asyncore
 import asynchat
+import collections
 from warnings import warn
+from email._header_value_parser import get_addr_spec, get_angle_addr
 __all__ = ["SMTPServer","DebuggingServer","PureProxy","MailmanProxy"]
 program = sys.argv[0]
-__version__ = 'Python SMTP proxy version 0.2'
+__version__ = 'Python SMTP proxy version 0.3'
 class Devnull:
@@ -94,9 +100,9 @@
 NEWLINE = '\n'
 def usage(code, msg=''):
     print(__doc__ % globals(), file=sys.stderr)
     if msg:
@@ -104,19 +110,23 @@
 class SMTPChannel(asynchat.async_chat):
     COMMAND = 0
     DATA = 1
-    data_size_limit = 33554432
     command_size_limit = 512
+    command_size_limits = collections.defaultdict(lambda x=command_size_limit: x)
+    command_size_limits.update({
+        'MAIL': command_size_limit + 26,
+        })
+    max_command_size_limit = max(command_size_limits.values())
-    def __init__(self, server, conn, addr):
+    def __init__(self, server, conn, addr, data_size_limit=DATA_SIZE_DEFAULT):
         asynchat.async_chat.__init__(self, conn)
         self.smtp_server = server
         self.conn = conn
         self.addr = addr
+        self.data_size_limit = data_size_limit
         self.received_lines = []
         self.smtp_state = self.COMMAND
         self.seen_greeting = ''
@@ -137,6 +147,7 @@
         print('Peer:', repr(self.peer), file=DEBUGSTREAM)
         self.push('220 %s %s' % (self.fqdn, __version__))
+        self.extended_smtp = False
     # properties for backwards-compatibility
@@ -268,7 +279,7 @@
     def collect_incoming_data(self, data):
         limit = None
         if self.smtp_state == self.COMMAND:
-            limit = self.command_size_limit
+            limit = self.max_command_size_limit
         elif self.smtp_state == self.DATA:
             limit = self.data_size_limit
         if limit and self.num_bytes > limit:
@@ -283,11 +294,7 @@
         print('Data:', repr(line), file=DEBUGSTREAM)
         self.received_lines = []
         if self.smtp_state == self.COMMAND:
-            if self.num_bytes > self.command_size_limit:
-                self.push('500 Error: line too long')
-                self.num_bytes = 0
-                return
-            self.num_bytes = 0
+            sz, self.num_bytes = self.num_bytes, 0
             if not line:
                 self.push('500 Error: bad syntax')
@@ -299,9 +306,14 @@
                 command = line[:i].upper()
                 arg = line[i+1:].strip()
+            max_sz = (self.command_size_limits[command]
+                        if self.extended_smtp else self.command_size_limit)
+            if sz > max_sz:
+                self.push('500 Error: line too long')
+                return
             method = getattr(self, 'smtp_' + command, None)
             if not method:
-                self.push('502 Error: command "%s" not implemented' % command)
+                self.push('500 Error: command "%s" not recognized' % command)
@@ -310,12 +322,12 @@
                 self.push('451 Internal confusion')
                 self.num_bytes = 0
-            if self.num_bytes > self.data_size_limit:
+            if self.data_size_limit and self.num_bytes > self.data_size_limit:
                 self.push('552 Error: Too much mail data')
                 self.num_bytes = 0
             # Remove extraneous carriage returns and de-transparency according
-            # to RFC 821, Section 4.5.2.
+            # to RFC 5321, Section 4.5.2.
             data = []
             for text in line.split('\r\n'):
                 if text and text[0] == '.':
@@ -333,7 +345,7 @@
             self.num_bytes = 0
             if not status:
-                self.push('250 Ok')
+                self.push('250 OK')
@@ -346,66 +358,188 @@
             self.push('503 Duplicate HELO/EHLO')
             self.seen_greeting = arg
+            self.extended_smtp = False
             self.push('250 %s' % self.fqdn)
+    def smtp_EHLO(self, arg):
+        if not arg:
+            self.push('501 Syntax: EHLO hostname')
+            return
+        if self.seen_greeting:
+            self.push('503 Duplicate HELO/EHLO')
+        else:
+            self.seen_greeting = arg
+            self.extended_smtp = True
+            self.push('250-%s' % self.fqdn)
+            if self.data_size_limit:
+                self.push('250-SIZE %s' % self.data_size_limit)
+            self.push('250 HELP')
     def smtp_NOOP(self, arg):
         if arg:
             self.push('501 Syntax: NOOP')
-            self.push('250 Ok')
+            self.push('250 OK')
     def smtp_QUIT(self, arg):
         # args is ignored
         self.push('221 Bye')
-    # factored
-    def __getaddr(self, keyword, arg):
-        address = None
+    def _strip_command_keyword(self, keyword, arg):
         keylen = len(keyword)
         if arg[:keylen].upper() == keyword:
-            address = arg[keylen:].strip()
-            if not address:
-                pass
-            elif address[0] == '<' and address[-1] == '>' and address != '<>':
-                # Addresses can be in the form <person@dom.com> but watch out
-                # for null address, e.g. <>
-                address = address[1:-1]
-        return address
+            return arg[keylen:].strip()
+        return ''
+    def _getaddr(self, arg):
+        if not arg:
+            return '', ''
+        if arg.lstrip().startswith('<'):
+            address, rest = get_angle_addr(arg)
+        else:
+            address, rest = get_addr_spec(arg)
+        if not address:
+            return address, rest
+        return address.addr_spec, rest
+    def _getparams(self, params):
+        # Return any parameters that appear to be syntactically valid according
+        # to RFC 1869, ignore all others.  (Postel rule: accept what we can.)
+        params = [param.split('=', 1) for param in params.split()
+                                      if '=' in param]
+        return {k: v for k, v in params if k.isalnum()}
+    def smtp_HELP(self, arg):
+        if arg:
+            extended = ' [SP <mail parameters]'
+            lc_arg = arg.upper()
+            if lc_arg == 'EHLO':
+                self.push('250 Syntax: EHLO hostname')
+            elif lc_arg == 'HELO':
+                self.push('250 Syntax: HELO hostname')
+            elif lc_arg == 'MAIL':
+                msg = '250 Syntax: MAIL FROM: <address>'
+                if self.extended_smtp:
+                    msg += extended
+                self.push(msg)
+            elif lc_arg == 'RCPT':
+                msg = '250 Syntax: RCPT TO: <address>'
+                if self.extended_smtp:
+                    msg += extended
+                self.push(msg)
+            elif lc_arg == 'DATA':
+                self.push('250 Syntax: DATA')
+            elif lc_arg == 'RSET':
+                self.push('250 Syntax: RSET')
+            elif lc_arg == 'NOOP':
+                self.push('250 Syntax: NOOP')
+            elif lc_arg == 'QUIT':
+                self.push('250 Syntax: QUIT')
+            elif lc_arg == 'VRFY':
+                self.push('250 Syntax: VRFY <address>')
+            else:
+                self.push('501 Supported commands: EHLO HELO MAIL RCPT '
+                          'DATA RSET NOOP QUIT VRFY')
+        else:
+            self.push('250 Supported commands: EHLO HELO MAIL RCPT DATA '
+                      'RSET NOOP QUIT VRFY')
+    def smtp_VRFY(self, arg):
+        if arg:
+            address, params = self._getaddr(arg)
+            if address:
+                self.push('252 Cannot VRFY user, but will accept message '
+                          'and attempt delivery')
+            else:
+                self.push('502 Could not VRFY %s' % arg)
+        else:
+            self.push('501 Syntax: VRFY <address>')
     def smtp_MAIL(self, arg):
         if not self.seen_greeting:
             self.push('503 Error: send HELO first');
         print('===> MAIL', arg, file=DEBUGSTREAM)
-        address = self.__getaddr('FROM:', arg) if arg else None
+        syntaxerr = '501 Syntax: MAIL FROM: <address>'
+        if self.extended_smtp:
+            syntaxerr += ' [SP <mail-parameters>]'
+        if arg is None:
+            self.push(syntaxerr)
+            return
+        arg = self._strip_command_keyword('FROM:', arg)
+        address, params = self._getaddr(arg)
         if not address:
-            self.push('501 Syntax: MAIL FROM:<address>')
+            self.push(syntaxerr)
+            return
+        if not self.extended_smtp and params:
+            self.push(syntaxerr)
+            return
+        if not address:
+            self.push(syntaxerr)
         if self.mailfrom:
             self.push('503 Error: nested MAIL command')
+        params = self._getparams(params.upper())
+        if params is None:
+            self.push(syntaxerr)
+            return
+        size = params.pop('SIZE', None)
+        if size:
+            if not size.isdigit():
+                self.push(syntaxerr)
+                return
+            elif self.data_size_limit and int(size) > self.data_size_limit:
+                self.push('552 Error: message size exceeds fixed maximum message size')
+                return
+        if len(params.keys()) > 0:
+            self.push('555 MAIL FROM parameters not recognized or not implemented')
+            return
         self.mailfrom = address
         print('sender:', self.mailfrom, file=DEBUGSTREAM)
-        self.push('250 Ok')
+        self.push('250 OK')
     def smtp_RCPT(self, arg):
         if not self.seen_greeting:
             self.push('503 Error: send HELO first');
         print('===> RCPT', arg, file=DEBUGSTREAM)
         if not self.mailfrom:
             self.push('503 Error: need MAIL command')
-        address = self.__getaddr('TO:', arg) if arg else None
+        syntaxerr = '501 Syntax: RCPT TO: <address>'
+        if self.extended_smtp:
+            syntaxerr += ' [SP <mail-parameters>]'
+        if arg is None:
+            self.push(syntaxerr)
+            return
+        arg = self._strip_command_keyword('TO:', arg)
+        address, params = self._getaddr(arg)
+        if not address:
+            self.push(syntaxerr)
+            return
+        if params:
+            if self.extended_smtp:
+                params = self._getparams(params.upper())
+                if params is None:
+                    self.push(syntaxerr)
+                    return
+            else:
+                self.push(syntaxerr)
+                return
+        if not address:
+            self.push(syntaxerr)
+            return
+        if params and len(params.keys()) > 0:
+            self.push('555 RCPT TO parameters not recognized or not implemented')
+            return
         if not address:
             self.push('501 Syntax: RCPT TO: <address>')
         print('recips:', self.rcpttos, file=DEBUGSTREAM)
-        self.push('250 Ok')
+        self.push('250 OK')
     def smtp_RSET(self, arg):
         if arg:
@@ -416,13 +550,12 @@
         self.rcpttos = []
         self.received_data = ''
         self.smtp_state = self.COMMAND
-        self.push('250 Ok')
+        self.push('250 OK')
     def smtp_DATA(self, arg):
         if not self.seen_greeting:
             self.push('503 Error: send HELO first');
         if not self.rcpttos:
             self.push('503 Error: need RCPT command')
@@ -433,15 +566,20 @@
         self.push('354 End data with <CR><LF>.<CR><LF>')
+    # Commands that have not been implemented
+    def smtp_EXPN(self, arg):
+        self.push('502 EXPN not implemented')
 class SMTPServer(asyncore.dispatcher):
     # SMTPChannel class to use for managing client connections
     channel_class = SMTPChannel
-    def __init__(self, localaddr, remoteaddr):
+    def __init__(self, localaddr, remoteaddr,
+                 data_size_limit=DATA_SIZE_DEFAULT):
         self._localaddr = localaddr
         self._remoteaddr = remoteaddr
+        self.data_size_limit = data_size_limit
             self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -459,7 +597,7 @@
     def handle_accepted(self, conn, addr):
         print('Incoming connection from %s' % repr(addr), file=DEBUGSTREAM)
-        channel = self.channel_class(self, conn, addr)
+        channel = self.channel_class(self, conn, addr, self.data_size_limit)
     # API for "doing something useful with the message"
     def process_message(self, peer, mailfrom, rcpttos, data):
@@ -487,7 +625,6 @@
         raise NotImplementedError
 class DebuggingServer(SMTPServer):
     # Do something with the gathered message
     def process_message(self, peer, mailfrom, rcpttos, data):
@@ -503,7 +640,6 @@
         print('------------ END MESSAGE ------------')
 class PureProxy(SMTPServer):
     def process_message(self, peer, mailfrom, rcpttos, data):
         lines = data.split('\n')
@@ -544,7 +680,6 @@
         return refused
 class MailmanProxy(PureProxy):
     def process_message(self, peer, mailfrom, rcpttos, data):
         from io import StringIO
@@ -623,19 +758,18 @@
                 msg.Enqueue(mlist, torequest=1)
 class Options:
     setuid = 1
     classname = 'PureProxy'
+    size_limit = None
 def parseargs():
     global DEBUGSTREAM
         opts, args = getopt.getopt(
-            sys.argv[1:], 'nVhc:d',
-            ['class=', 'nosetuid', 'version', 'help', 'debug'])
+            sys.argv[1:], 'nVhc:s:d',
+            ['class=', 'nosetuid', 'version', 'help', 'size=', 'debug'])
     except getopt.error as e:
         usage(1, e)
@@ -652,6 +786,13 @@
             options.classname = arg
         elif opt in ('-d', '--debug'):
             DEBUGSTREAM = sys.stderr
+        elif opt in ('-s', '--size'):
+            try:
+                int_size = int(arg)
+                options.size_limit = int_size
+            except:
+                print('Invalid size: ' + arg, file=sys.stderr)
+                sys.exit(1)
     # parse the rest of the arguments
     if len(args) < 1:
@@ -686,7 +827,6 @@
     return options
 if __name__ == '__main__':
     options = parseargs()
     # Become nobody
@@ -699,7 +839,8 @@
         import __main__ as mod
     class_ = getattr(mod, classname)
     proxy = class_((options.localhost, options.localport),
-                   (options.remotehost, options.remoteport))
+                   (options.remotehost, options.remoteport),
+                   options.size_limit)
     if options.setuid:
             import pwd