import ssl import socket import asyncio import logging import collections import time from asyncio import sslproto from email._header_value_parser import get_addr_spec, get_angle_addr from email.errors import HeaderParseError from public import public from warnings import warn __version__ = '1.2+' __ident__ = 'Python SMTP {}'.format(__version__) log = logging.getLogger('mail.log') DATA_SIZE_DEFAULT = 33554432 EMPTYBYTES = b'' NEWLINE = '\n' @public class Session: def __init__(self, loop): self.peer = None self.ssl = None self.host_name = None self.extended_smtp = False self.loop = loop self.started_at = int(time.time() * 1000) def get_duration(self): now = int(time.time() * 1000) return now - self.started_at @public class Envelope: def __init__(self): self.mail_from = None self.mail_options = [] self.smtp_utf8 = False self.content = None self.original_content = None self.rcpt_tos = [] self.rcpt_options = [] # This is here to enable debugging output when the -E option is given to the # unit test suite. In that case, this function is mocked to set the debug # level on the loop (as if PYTHONASYNCIODEBUG=1 were set). def make_loop(): return asyncio.get_event_loop() def syntax(text, extended=None, when=None): def decorator(f): f.__smtp_syntax__ = text f.__smtp_syntax_extended__ = extended f.__smtp_syntax_when__ = when return f return decorator @public class AioSMTP(asyncio.StreamReaderProtocol): command_size_limit = 512 command_size_limits = collections.defaultdict( lambda x=command_size_limit: x) def __init__(self, handler, *, data_size_limit=DATA_SIZE_DEFAULT, enable_SMTPUTF8=False, decode_data=False, hostname=None, ident=None, tls_context=None, require_starttls=False, timeout=300, global_timeout=300, command_call_limit=10, loop=None): self.__ident__ = ident or __ident__ self.loop = loop if loop else make_loop() super().__init__( asyncio.StreamReader(loop=self.loop), client_connected_cb=self._client_connected_cb, loop=self.loop) self.event_handler = handler self.data_size_limit = data_size_limit self.enable_SMTPUTF8 = enable_SMTPUTF8 self._decode_data = decode_data self.command_size_limits.clear() if hostname: self.hostname = hostname else: self.hostname = socket.getfqdn() self.tls_context = tls_context if tls_context: # Through rfc3207 part 4.1 certificate checking is part of SMTP # protocol, not SSL layer. self.tls_context.check_hostname = False self.tls_context.verify_mode = ssl.CERT_NONE self.require_starttls = tls_context and require_starttls self._timeout_duration = timeout self._timeout_handle = None self._global_timeout_duration = global_timeout self._global_timeout_handle = None self._tls_handshake_okay = True self._tls_protocol = None self._original_transport = None self.session = None self.envelope = None self.transport = None self._handler_coroutine = None # This will be used for call limiting the commands self._command_call_count = {} self._command_call_limit = None self._command_call_limit_global = command_call_limit def _create_session(self): return Session(self.loop) def _create_envelope(self): return Envelope() async def _call_handler_hook(self, command, *args): hook = getattr(self.event_handler, 'handle_' + command, None) if hook is not None: return await hook(self, self.session, self.envelope, *args) return None @property def max_command_size_limit(self): try: return max(self.command_size_limits.values()) except ValueError: return self.command_size_limit def connection_made(self, transport): # Reset state due to rfc3207 part 4.2. self._set_rset_state() self.session = self._create_session() self.session.peer = transport.get_extra_info('peername') seen_starttls = (self._original_transport is not None) if self.transport is not None and seen_starttls: # It is STARTTLS connection over normal connection. self._reader._transport = transport self._writer._transport = transport self.transport = transport # Do SSL certificate checking as rfc3207 part 4.1 says. Why is # _extra a protected attribute? self.session.ssl = self._tls_protocol._extra handler = getattr(self.event_handler, 'handle_STARTTLS', None) if handler is None: self._tls_handshake_okay = True else: self._tls_handshake_okay = handler( self, self.session, self.envelope) else: super().connection_made(transport) self.transport = transport log.info('Peer: %r', self.session.peer) # Process the client's requests. self._handler_coroutine = self.loop.create_task( self._handle_client()) def connection_lost(self, error): log.info('%r connection lost', self.session.peer) self._cancel_timeouts() # If STARTTLS was issued, then our transport is the SSL protocol # transport, and we need to close the original transport explicitly, # otherwise an unexpected eof_received() will be called *after* the # connection_lost(). At that point the stream reader will already be # destroyed and we'll get a traceback in super().eof_received() below. if self._original_transport is not None: self._original_transport.close() super().connection_lost(error) self._handler_coroutine.cancel() self.transport = None def eof_received(self): log.info('%r EOF received', self.session.peer) self._cancel_timeouts() self._handler_coroutine.cancel() if self.session.ssl is not None: # pragma: nomswin # If STARTTLS was issued, return False, because True has no effect # on an SSL transport and raises a warning. Our superclass has no # way of knowing we switched to SSL so it might return True. # # This entire method seems not to be called during any of the # starttls tests on Windows. I don't really know why, but it # causes these lines to fail coverage, hence the `nomswin` pragma # above. return False return super().eof_received() def _cancel_timeouts(self): if self._timeout_handle is not None: self._timeout_handle.cancel() self._timeout_handle = None if self._global_timeout_handle is not None: self._global_timeout_handle.cancel() self._timeout_handle = None def _reset_timeout(self): log.info('_reset_timeout') if self._timeout_handle is not None: self._timeout_handle.cancel() self._timeout_handle = self.loop.call_later( self._timeout_duration, self._timeout_cb ) def _reset_global_timeout(self): log.info('_reset_global_timeout') if self._global_timeout_handle is not None: self._global_timeout_handle.cancel() self._global_timeout_handle = self.loop.call_later( self._global_timeout_duration, self._timeout_cb ) def _timeout_cb(self): # In some case, the connection was lost before calling the previous line # So we re-check if transport is not None, to close it. if self.transport: # Calling close() on the transport will trigger connection_lost(), # which gracefully closes the SSL transport if required and cleans # up state. self.transport.close() """ def _reset_timeout(self): log.info('_reset_timeout') if self._timeout_handle is not None: self._timeout_handle.cancel() self._timeout_handle = self._call_later_sync( self._timeout_duration, self._timeout_cb ) def _reset_global_timeout(self): log.info('_reset_global_timeout') if self._global_timeout_handle is not None: self._global_timeout_handle.cancel() self._global_timeout_handle = self._call_later_sync( self._global_timeout_duration, self._timeout_cb ) def _call_later_sync(self, duration, callback): return asyncio.ensure_future(callback(duration), loop=self.loop) async def _timeout_cb(self, duration): await asyncio.sleep(duration) log.info('%r connection timeout', self.session.peer) if self.transport: try: result = await self.push("550 Timeout error.") except Exception as e: # Connection was lost pass # In some case, the connection was lost before calling the previous line # So we re-check if transport is not None, to close it. if self.transport: # Calling close() on the transport will trigger connection_lost(), # which gracefully closes the SSL transport if required and cleans # up state. self.transport.close() """ def _client_connected_cb(self, reader, writer): # This is redundant since we subclass StreamReaderProtocol, but I like # the shorter names. self._reader = reader self._writer = writer def _set_post_data_state(self): """Reset state variables to their post-DATA state.""" self.envelope = self._create_envelope() def _set_rset_state(self): """Reset all state variables except the greeting.""" self._set_post_data_state() self._reset_timeout() def is_command_call_limited(self, command): """Check if that specific command has been called too many times""" if command not in self._command_call_count: self._command_call_count[command] = 0 self._command_call_count[command] += 1 limit = self._command_call_limit_global if isinstance(self._command_call_limit, dict) and command in self._command_call_limit: limit = self._command_call_limit[command] return self._command_call_count[command] > limit def set_command_call_limit(self, options): """Set a limit per command options. For instance, RCPT could accept more than MAIL""" assert isinstance(options, dict) self._command_call_limit = options async def push(self, status): response = bytes( status + '\r\n', 'utf-8' if self.enable_SMTPUTF8 else 'ascii') self._writer.write(response) log.debug(response) await self._writer.drain() async def handle_exception(self, error): if hasattr(self.event_handler, 'handle_exception'): status = await self.event_handler.handle_exception(error) return status else: log.exception('SMTP session exception') status = '500 Error: ({}) {}'.format( error.__class__.__name__, str(error)) return status async def _handle_client(self): log.info('%r handling connection', self.session.peer) await self.push('220 {} {}'.format(self.hostname, self.__ident__)) while self.transport is not None: # pragma: nobranch # XXX Put the line limit stuff into the StreamReader? try: line = await self._reader.readline() log.debug('_handle_client readline: %s', line) # XXX this rstrip may not completely preserve old behavior. line = line.rstrip(b'\r\n') log.info('%r Data: %s', self.session.peer, line) if not line: await self.push('500 Error: bad syntax') continue i = line.find(b' ') # Decode to string only the command name part, which must be # ASCII as per RFC. If there is an argument, it is decoded to # UTF-8/surrogateescape so that non-UTF-8 data can be # re-encoded back to the original bytes when the SMTP command # is handled. if i < 0: try: command = line.upper().decode(encoding='ascii') except UnicodeDecodeError: await self.push('500 Error: bad syntax') continue arg = None else: try: command = line[:i].upper().decode(encoding='ascii') except UnicodeDecodeError: await self.push('500 Error: bad syntax') continue arg = line[i+1:].strip() # Remote SMTP servers can send us UTF-8 content despite # whether they've declared to do so or not. Some old # servers can send 8-bit data. Use surrogateescape so # that the fidelity of the decoding is preserved, and the # original bytes can be retrieved. if self.enable_SMTPUTF8: arg = str( arg, encoding='utf-8', errors='surrogateescape') else: try: arg = str(arg, encoding='ascii', errors='strict') except UnicodeDecodeError: # This happens if enable_SMTPUTF8 is false, meaning # that the server explicitly does not want to # accept non-ASCII, but the client ignores that and # sends non-ASCII anyway. await self.push('500 Error: strict ASCII mode') # Should we await self.handle_exception()? continue max_sz = (self.command_size_limits[command] if self.session.extended_smtp else self.command_size_limit) if len(line) > max_sz: await self.push('500 Error: line too long') continue if not self._tls_handshake_okay and command != 'QUIT': await self.push( '554 Command refused due to lack of security') continue if (self.require_starttls and not self._tls_protocol and command not in ['EHLO', 'STARTTLS', 'QUIT']): # RFC3207 part 4 await self.push('530 Must issue a STARTTLS command first') continue if self.is_command_call_limited(command): await self.push( '500 Error: command "%s" sent too many times' % command) self.transport.close() continue method = getattr(self, 'smtp_' + command, None) if method is None: await self.push( '500 Error: command "%s" not recognized' % command) continue # Received a valid command, reset the timer. self._reset_timeout() await method(arg) except asyncio.CancelledError: # The connection got reset during the DATA command. # XXX If handler method raises ConnectionResetError, we should # verify that it was actually self._reader that was reset. log.info('Connection lost during _handle_client()') self._writer.close() raise except Exception as error: try: status = await self.handle_exception(error) await self.push(status) except Exception as error: try: log.exception('Exception in handle_exception()') status = '500 Error: ({}) {}'.format( error.__class__.__name__, str(error)) except Exception: status = '500 Error: Cannot describe error' await self.push(status) # SMTP and ESMTP commands @syntax('HELO hostname') async def smtp_HELO(self, hostname): if not hostname: await self.push('501 Syntax: HELO hostname') return self._set_rset_state() self.session.extended_smtp = False status = await self._call_handler_hook('HELO', hostname) if status is None: self.session.host_name = hostname status = '250 {}'.format(self.hostname) await self.push(status) @syntax('EHLO hostname') async def smtp_EHLO(self, hostname): if not hostname: await self.push('501 Syntax: EHLO hostname') return response = [] self._set_rset_state() self.session.extended_smtp = True response.append('250-%s' % self.hostname) if self.data_size_limit: response.append('250-SIZE %s' % self.data_size_limit) self.command_size_limits['MAIL'] += 26 if not self._decode_data: response.append('250-8BITMIME') if self.enable_SMTPUTF8: response.append('250-SMTPUTF8') self.command_size_limits['MAIL'] += 10 if self.tls_context and not self._tls_protocol: response.append('250-STARTTLS') if hasattr(self, 'ehlo_hook'): warn('Use handler.handle_EHLO() instead of .ehlo_hook()', DeprecationWarning) await self.ehlo_hook() status = await self._call_handler_hook('EHLO', hostname) if status is None: self.session.host_name = hostname response.append('250 HELP') for r in response: await self.push(r) else: await self.push(status) @syntax('NOOP [ignored]') async def smtp_NOOP(self, arg): status = await self._call_handler_hook('NOOP', arg) await self.push('250 OK' if status is None else status) @syntax('QUIT') async def smtp_QUIT(self, arg): if arg: await self.push('501 Syntax: QUIT') else: status = await self._call_handler_hook('QUIT') await self.push('221 Bye' if status is None else status) self._handler_coroutine.cancel() self.transport.close() @syntax('STARTTLS', when='tls_context') async def smtp_STARTTLS(self, arg): log.info('%r STARTTLS', self.session.peer) if arg: await self.push('501 Syntax: STARTTLS') return if not self.tls_context: await self.push('454 TLS not available') return await self.push('220 Ready to start TLS') # Create SSL layer. self._tls_protocol = sslproto.SSLProtocol( self.loop, self, self.tls_context, None, server_side=True) # Reconfigure transport layer. Keep a reference to the original # transport so that we can close it explicitly when the connection is # lost. XXX BaseTransport.set_protocol() was added in Python 3.5.3 :( self._original_transport = self.transport self._original_transport._protocol = self._tls_protocol # Reconfigure the protocol layer. Why is the app transport a protected # property, if it MUST be used externally? self.transport = self._tls_protocol._app_transport self._tls_protocol.connection_made(self._original_transport) def _strip_command_keyword(self, keyword, arg): keylen = len(keyword) if arg[:keylen].upper() == keyword: return arg[keylen:].strip() return None def _getaddr(self, arg): if not arg: return '', '' if arg.lstrip().startswith('<'): address, rest = get_angle_addr(arg) else: address, rest = get_addr_spec(arg) try: address = address.addr_spec except IndexError: # Workaround http://bugs.python.org/issue27931 address = None return address, rest def _getparams(self, params): # Return params as dictionary. Return None if not all parameters # appear to be syntactically valid according to RFC 1869. result = {} for param in params: param, eq, value = param.partition('=') if not param.isalnum() or eq and not value: return None result[param] = value if eq else True return result def _syntax_available(self, method): if getattr(method, '__smtp_syntax__', None) is None: return False if method.__smtp_syntax_when__: return bool(getattr(self, method.__smtp_syntax_when__)) return True @syntax('HELP [command]') async def smtp_HELP(self, arg): code = 250 if arg: method = getattr(self, 'smtp_' + arg.upper(), None) if method and self._syntax_available(method): help_str = method.__smtp_syntax__ if (self.session.extended_smtp and method.__smtp_syntax_extended__): help_str += method.__smtp_syntax_extended__ await self.push('250 Syntax: ' + help_str) return code = 501 commands = [] for name in dir(self): if not name.startswith('smtp_'): continue method = getattr(self, name) if self._syntax_available(method): commands.append(name.lstrip('smtp_')) commands.sort() await self.push( '{} Supported commands: {}'.format(code, ' '.join(commands))) @syntax('VRFY
') async def smtp_VRFY(self, arg): if arg: try: address, params = self._getaddr(arg) except HeaderParseError: address = None if address is None: await self.push('502 Could not VRFY %s' % arg) else: status = await self._call_handler_hook('VRFY', address) await self.push( '252 Cannot VRFY user, but will accept message ' 'and attempt delivery' if status is None else status) else: await self.push('501 Syntax: VRFY ') @syntax('MAIL FROM: ', extended=' [SP