# vim: ts=4 sw=4 et ai:
# -*- coding: utf8 -*-
"""This module implements all contexts for state handling during uploads and
downloads, the main interface to which being the TftpContext base class.

The concept is simple. Each context object represents a single upload or
download, and the state object in the context object represents the current
state of that transfer. The state object has a handle() method that expects
the next packet in the transfer, and returns a state object until the transfer
is complete, at which point it returns None. That is, unless there is a fatal
error, in which case a TftpException is raised instead."""

from .TftpShared import *
from .TftpPacketTypes import *
from .TftpPacketFactory import TftpPacketFactory
from .TftpStates import *
import socket
import time
import sys
import os
import logging

log = logging.getLogger('tftpy.TftpContext')

###############################################################################
# Utility classes
###############################################################################

class TftpMetrics(object):
    """A class representing metrics of the transfer."""
    def __init__(self):
        # Bytes transferred
        self.bytes = 0
        # Bytes re-sent
        self.resent_bytes = 0
        # Duplicate packets received, keyed by str(packet) -> count.
        self.dups = {}
        # Total number of duplicates; derived from self.dups by compute().
        self.dupcount = 0
        # Times
        self.start_time = 0
        self.end_time = 0
        self.duration = 0
        # Rates
        self.bps = 0
        self.kbps = 0
        # Generic errors
        self.errors = 0

    def compute(self):
        """Derive duration, bit rates and the total duplicate count from the
        raw counters.

        Safe to call more than once: every derived field is recomputed from
        scratch rather than accumulated."""
        # Compute transfer time
        self.duration = self.end_time - self.start_time
        if self.duration == 0:
            # Avoid dividing by zero for transfers shorter than the clock
            # resolution.
            self.duration = 1
        log.debug("TftpMetrics.compute: duration is %s", self.duration)
        self.bps = (self.bytes * 8.0) / self.duration
        self.kbps = self.bps / 1024.0
        log.debug("TftpMetrics.compute: kbps is %s", self.kbps)
        # Recompute instead of accumulating so repeated calls (e.g. end()
        # being invoked twice) do not double-count duplicates.
        self.dupcount = sum(self.dups.values())

    def add_dup(self, pkt):
        """This method adds a dup for a packet to the metrics.

        Fails the tftpassert check once the same packet has been seen
        MAX_DUPS times."""
        log.debug("Recording a dup of %s", pkt)
        s = str(pkt)
        self.dups[s] = self.dups.get(s, 0) + 1
        tftpassert(self.dups[s] < MAX_DUPS, "Max duplicates reached")

###############################################################################
# Context classes
###############################################################################

class TftpContext(object):
    """The base class of the contexts."""

    def __init__(self, host, port, timeout, localip=""):
        """Constructor for the base context, setting shared instance
        variables.

        host and port identify the remote peer, timeout is applied to the
        socket with settimeout(), and localip, when non-empty, is the local
        address the socket is bound to (on an ephemeral port)."""
        # Everything that cannot fail is set first, so that end() (possibly
        # reached through __del__ after a failed construction) always finds
        # the attributes it touches.
        self._ended = False
        self.sock = None
        self.file_to_transfer = None
        self.fileobj = None
        # True when fileobj belongs to the caller (or is a standard stream)
        # and therefore must be left open by end().
        self.filelike_fileobj = False
        self.options = None
        self.packethook = None
        self.timeout = timeout
        self.state = None
        self.next_block = 0
        self.factory = TftpPacketFactory()
        # The port associated with the TID
        self.tidport = None
        # Metrics
        self.metrics = TftpMetrics()
        # Flag when the transfer is pending completion.
        self.pending_complete = False
        # Time when this context last accepted traffic from its peer.
        # FIXME: does this belong in metrics?
        self.last_update = 0
        # The last packet we sent, if applicable, to make resending easy.
        self.last_pkt = None
        # Count the number of retry attempts.
        self.retry_count = 0

        self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        if localip != "":
            self.sock.bind((localip, 0))
        self.sock.settimeout(timeout)
        # Note, setting the host will also set self.address, as it's a property.
        self.host = host
        self.port = port

    def __str__(self):
        return "%s:%s %s" % (self.host, self.port, self.state)

    def getBlocksize(self):
        """Fetch the current blocksize for this session, defaulting to 512
        when no 'blksize' option has been set."""
        options = self.options or {}
        return int(options.get('blksize', 512))

    def __del__(self):
        """Fallback cleanup in case end() was never called explicitly.
        Leaking file descriptors is not a good thing.

        Skipped when end() already ran, so the metrics recorded by an
        explicit end() are not overwritten at garbage-collection time."""
        if not getattr(self, '_ended', True):
            self.end()

    def checkTimeout(self, now):
        """Compare current time with last_update time, and raise an exception
        if we're over the timeout time."""
        log.debug("checking for timeout on session %s", self)
        if now - self.last_update > self.timeout:
            raise TftpTimeout("Timeout waiting for traffic")

    def start(self):
        raise NotImplementedError("Abstract method")

    def end(self, close_fileobj=True):
        """Perform session cleanup, since the end method should always be
        called explicitly by the calling code, this works better than the
        destructor. Set close_fileobj to False so fileobj can be returned
        open."""
        self._ended = True
        if self.sock is not None:
            log.debug("in TftpContext.end - closing socket")
            self.sock.close()
        if close_fileobj and self.fileobj is not None and not self.fileobj.closed:
            log.debug("self.fileobj is open - closing")
            self.fileobj.close()

    def gethost(self):
        "Simple getter method for use in a property."
        return self.__host

    def sethost(self, host):
        """Setter method that also sets the address property as a result
        of the host that is set."""
        self.__host = host
        self.address = socket.gethostbyname(host)

    host = property(gethost, sethost)

    def setNextBlock(self, block):
        # Block numbers are 16 bits on the wire; wrap back to 0.
        if block >= 2 ** 16:
            log.debug("Block number rollover to 0 again")
            block = 0
        self.__eblock = block

    def getNextBlock(self):
        return self.__eblock

    next_block = property(getNextBlock, setNextBlock)

    def _mark_start(self):
        """Record the transfer start time in the metrics."""
        self.metrics.start_time = time.time()
        log.debug("Set metrics.start_time to %s", self.metrics.start_time)

    def _mark_end(self):
        """Record the transfer end time and compute the final metrics."""
        self.metrics.end_time = time.time()
        log.debug("Set metrics.end_time to %s", self.metrics.end_time)
        self.metrics.compute()

    def _send_request(self, pkt):
        """Fill in and send an initial RRQ/WRQ packet, remembering it so the
        states can resend it."""
        pkt.filename = self.file_to_transfer
        pkt.mode = "octet" # FIXME - shouldn't hardcode this
        pkt.options = self.options
        # Send to the address resolved at construction time, which is the
        # same address cycle() accepts replies from.
        self.sock.sendto(pkt.encode().buffer, (self.address, self.port))
        self.next_block = 1
        # FIXME: should we centralize sendto operations so we can refactor all
        # saving of the packet to the last_pkt field?
        self.last_pkt = pkt

    def _cycle_until_done(self):
        """Run cycle() until the state machine finishes, resending the last
        packet after each timeout until TIMEOUT_RETRIES is reached."""
        while self.state:
            try:
                log.debug("State is %s", self.state)
                self.cycle()
            except TftpTimeout as err:
                log.error(str(err))
                self.retry_count += 1
                if self.retry_count >= TIMEOUT_RETRIES:
                    log.debug("hit max retries, giving up")
                    raise
                log.warning("resending last packet")
                self.state.resendLast()

    def cycle(self):
        """Here we wait for a response from the server after sending it
        something, and dispatch appropriate action to that response.

        Traffic that does not come from the expected peer address (or, once
        it is known, from the peer's TID port) is logged and discarded: it is
        not decoded, not passed to the packethook and not handed to the
        current state.

        Raises TftpTimeout if nothing arrives within the socket timeout."""
        try:
            (buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
        except socket.timeout:
            log.warning("Timeout waiting for traffic, retrying...")
            raise TftpTimeout("Timed-out waiting for traffic")

        # Ok, we've received a packet. Log it.
        log.debug("Received %d bytes from %s:%s",
                  len(buffer), raddress, rport)

        # Check for known "connection". This happens before decoding so a
        # malformed packet from a third party cannot abort the transfer.
        if raddress != self.address:
            log.warning("Received traffic from %s, expected host %s. Discarding",
                        raddress, self.host)
            return

        if self.tidport and self.tidport != rport:
            log.warning("Received traffic from %s:%s but we're "
                        "connected to %s:%s. Discarding.",
                        raddress, rport, self.host, self.tidport)
            return

        # And update our last updated time.
        self.last_update = time.time()

        # Decode it.
        recvpkt = self.factory.parse(buffer)

        # If there is a packethook defined, call it. We pass every packet
        # received from the peer, it's up to the client to screen out
        # different kinds of packets. This way, the client is privy to things
        # like negotiated options.
        if self.packethook:
            self.packethook(recvpkt)

        # And handle it, possibly changing state.
        self.state = self.state.handle(recvpkt, raddress, rport)
        # If we didn't throw any exceptions here, reset the retry_count to
        # zero.
        self.retry_count = 0


class TftpContextServer(TftpContext):
    """The context for the server."""
    def __init__(self, host, port, timeout, root, dyn_file_func=None, upload_open=None):
        TftpContext.__init__(self, host, port, timeout)
        # At this point we have no idea if this is a download or an upload. We
        # need to let the start state determine that.
        self.state = TftpStateServerStart(self)

        self.root = root
        self.dyn_file_func = dyn_file_func
        self.upload_open = upload_open

    def start(self, buffer):
        """Start the state cycle. Note that the server context receives an
        initial packet in its start method. Also note that the server does not
        loop on cycle(), as it expects the TftpServer object to manage
        that."""
        log.debug("In TftpContextServer.start")
        self._mark_start()
        # And update our last updated time.
        self.last_update = time.time()

        pkt = self.factory.parse(buffer)
        log.debug("TftpContextServer.start() - factory returned a %s", pkt)

        # Call handle once with the initial packet. This should put us into
        # the download or the upload state.
        self.state = self.state.handle(pkt, self.host, self.port)

    def end(self):
        """Finish up the context."""
        TftpContext.end(self)
        self._mark_end()


class TftpContextClientUpload(TftpContext):
    """The upload context for the client during an upload.
    Note: If input is a hyphen, then we will use stdin."""
    def __init__(self, host, port, filename, input, options, packethook, timeout, localip=""):
        TftpContext.__init__(self, host, port, timeout, localip)
        self.file_to_transfer = filename
        self.options = options
        self.packethook = packethook
        # If the input object has a read() function, assume it is file-like.
        # Caller-supplied objects and stdin are not ours to close in end().
        if hasattr(input, 'read'):
            self.fileobj = input
            self.filelike_fileobj = True
        elif input == '-':
            # Use the binary stream when available: text-mode stdin yields
            # str on Python 3, while the transfer needs raw bytes.
            self.fileobj = getattr(sys.stdin, 'buffer', sys.stdin)
            self.filelike_fileobj = True
        else:
            self.fileobj = open(input, "rb")

        log.debug("TftpContextClientUpload.__init__()")
        log.debug("file_to_transfer = %s, options = %s",
                  self.file_to_transfer, self.options)

    def start(self):
        """Send the WRQ and drive the upload until it completes."""
        log.info("Sending tftp upload request to %s", self.host)
        log.info(" filename -> %s", self.file_to_transfer)
        log.info(" options -> %s", self.options)

        self._mark_start()

        # FIXME: put this in a sendWRQ method?
        self._send_request(TftpPacketWRQ())
        self.state = TftpStateSentWRQ(self)
        self._cycle_until_done()

    def end(self):
        """Finish up the context, leaving caller-supplied input open."""
        TftpContext.end(self, not self.filelike_fileobj)
        self._mark_end()


class TftpContextClientDownload(TftpContext):
    """The download context for the client during a download.
    Note: If output is a hyphen, then the output will be sent to stdout."""
    def __init__(self, host, port, filename, output, options, packethook, timeout, localip=""):
        TftpContext.__init__(self, host, port, timeout, localip)
        # FIXME: should we refactor setting of these params?
        self.file_to_transfer = filename
        self.options = options
        self.packethook = packethook
        # If the output object has a write() function,
        # assume it is file-like.
        if hasattr(output, 'write'):
            self.fileobj = output
            self.filelike_fileobj = True
        # If the output filename is -, then use stdout
        elif output == '-':
            # Use the binary stream when available: text-mode stdout only
            # accepts str on Python 3, while the transfer produces bytes.
            self.fileobj = getattr(sys.stdout, 'buffer', sys.stdout)
            self.filelike_fileobj = True
        else:
            self.fileobj = open(output, "wb")

        log.debug("TftpContextClientDownload.__init__()")
        log.debug("file_to_transfer = %s, options = %s",
                  self.file_to_transfer, self.options)

    def start(self):
        """Initiate the download."""
        log.info("Sending tftp download request to %s", self.host)
        log.info(" filename -> %s", self.file_to_transfer)
        log.info(" options -> %s", self.options)

        self._mark_start()

        # FIXME: put this in a sendRRQ method?
        self._send_request(TftpPacketRRQ())
        self.state = TftpStateSentRRQ(self)

        try:
            self._cycle_until_done()
        except TftpFileNotFoundError:
            # If we received file not found, then we should not save the open
            # output file or we'll be left with a size zero file. Delete it,
            # if it exists.
            log.error("Received File not found error")
            if self.fileobj is not None and not self.filelike_fileobj:
                # Close before unlinking: an open file cannot be removed on
                # Windows. end() skips already-closed file objects.
                self.fileobj.close()
                if os.path.exists(self.fileobj.name):
                    log.debug("unlinking output file of %s", self.fileobj.name)
                    os.unlink(self.fileobj.name)
            raise

    def end(self):
        """Finish up the context, leaving caller-supplied output open."""
        TftpContext.end(self, not self.filelike_fileobj)
        self._mark_end()