diff options
Diffstat (limited to 'tester/rt/tftpy/TftpStates.py')
-rw-r--r-- | tester/rt/tftpy/TftpStates.py | 598 |
1 files changed, 598 insertions, 0 deletions
diff --git a/tester/rt/tftpy/TftpStates.py b/tester/rt/tftpy/TftpStates.py new file mode 100644 index 0000000..801e970 --- /dev/null +++ b/tester/rt/tftpy/TftpStates.py @@ -0,0 +1,598 @@ +"""This module implements all state handling during uploads and downloads, the +main interface to which being the TftpState base class. + +The concept is simple. Each context object represents a single upload or +download, and the state object in the context object represents the current +state of that transfer. The state object has a handle() method that expects +the next packet in the transfer, and returns a state object until the transfer +is complete, at which point it returns None. That is, unless there is a fatal +error, in which case a TftpException is returned instead.""" + +from __future__ import absolute_import, division, print_function, unicode_literals +from .TftpShared import * +from .TftpPacketTypes import * +import os + +############################################################################### +# State classes +############################################################################### + +class TftpState(object): + """The base class for the states.""" + + def __init__(self, context): + """Constructor for setting up common instance variables. The involved + file object is required, since in tftp there's always a file + involved.""" + self.context = context + + def handle(self, pkt, raddress, rport): + """An abstract method for handling a packet. It is expected to return + a TftpState object, either itself or a new state.""" + raise NotImplementedError("Abstract method") + + def handleOACK(self, pkt): + """This method handles an OACK from the server, syncing any accepted + options.""" + if pkt.options.keys() > 0: + if pkt.match_options(self.context.options): + log.info("Successful negotiation of options") + # Set options to OACK options + self.context.options = pkt.options + for key in self.context.options: + log.info(" %s = %s" % (key, self.context.options[key])) + else: + log.error("Failed to negotiate options") + raise TftpException("Failed to negotiate options") + else: + raise TftpException("No options found in OACK") + + def returnSupportedOptions(self, options): + """This method takes a requested options list from a client, and + returns the ones that are supported.""" + # We support the options blksize and tsize right now. + # FIXME - put this somewhere else? + accepted_options = {} + for option in options: + if option == 'blksize': + # Make sure it's valid. + if int(options[option]) > MAX_BLKSIZE: + log.info("Client requested blksize greater than %d " + "setting to maximum" % MAX_BLKSIZE) + accepted_options[option] = MAX_BLKSIZE + elif int(options[option]) < MIN_BLKSIZE: + log.info("Client requested blksize less than %d " + "setting to minimum" % MIN_BLKSIZE) + accepted_options[option] = MIN_BLKSIZE + else: + accepted_options[option] = options[option] + elif option == 'tsize': + log.debug("tsize option is set") + accepted_options['tsize'] = 0 + else: + log.info("Dropping unsupported option '%s'" % option) + log.debug("Returning these accepted options: %s", accepted_options) + return accepted_options + + def sendDAT(self): + """This method sends the next DAT packet based on the data in the + context. It returns a boolean indicating whether the transfer is + finished.""" + finished = False + blocknumber = self.context.next_block + # Test hook + if DELAY_BLOCK and DELAY_BLOCK == blocknumber: + import time + log.debug("Deliberately delaying 10 seconds...") + time.sleep(10) + dat = None + blksize = self.context.getBlocksize() + buffer = self.context.fileobj.read(blksize) + log.debug("Read %d bytes into buffer", len(buffer)) + if len(buffer) < blksize: + log.info("Reached EOF on file %s" + % self.context.file_to_transfer) + finished = True + dat = TftpPacketDAT() + dat.data = buffer + dat.blocknumber = blocknumber + self.context.metrics.bytes += len(dat.data) + log.debug("Sending DAT packet %d", dat.blocknumber) + self.context.sock.sendto(dat.encode().buffer, + (self.context.host, self.context.tidport)) + if self.context.packethook: + self.context.packethook(dat) + self.context.last_pkt = dat + return finished + + def sendACK(self, blocknumber=None): + """This method sends an ack packet to the block number specified. If + none is specified, it defaults to the next_block property in the + parent context.""" + log.debug("In sendACK, passed blocknumber is %s", blocknumber) + if blocknumber is None: + blocknumber = self.context.next_block + log.info("Sending ack to block %d" % blocknumber) + ackpkt = TftpPacketACK() + ackpkt.blocknumber = blocknumber + self.context.sock.sendto(ackpkt.encode().buffer, + (self.context.host, + self.context.tidport)) + self.context.last_pkt = ackpkt + + def sendError(self, errorcode): + """This method uses the socket passed, and uses the errorcode to + compose and send an error packet.""" + log.debug("In sendError, being asked to send error %d", errorcode) + errpkt = TftpPacketERR() + errpkt.errorcode = errorcode + self.context.sock.sendto(errpkt.encode().buffer, + (self.context.host, + self.context.tidport)) + self.context.last_pkt = errpkt + + def sendOACK(self): + """This method sends an OACK packet with the options from the current + context.""" + log.debug("In sendOACK with options %s", self.context.options) + pkt = TftpPacketOACK() + pkt.options = self.context.options + self.context.sock.sendto(pkt.encode().buffer, + (self.context.host, + self.context.tidport)) + self.context.last_pkt = pkt + + def resendLast(self): + "Resend the last sent packet due to a timeout." + log.warn("Resending packet %s on sessions %s" + % (self.context.last_pkt, self)) + self.context.metrics.resent_bytes += len(self.context.last_pkt.buffer) + self.context.metrics.add_dup(self.context.last_pkt) + sendto_port = self.context.tidport + if not sendto_port: + # If the tidport wasn't set, then the remote end hasn't even + # started talking to us yet. That's not good. Maybe it's not + # there. + sendto_port = self.context.port + self.context.sock.sendto(self.context.last_pkt.encode().buffer, + (self.context.host, sendto_port)) + if self.context.packethook: + self.context.packethook(self.context.last_pkt) + + def handleDat(self, pkt): + """This method handles a DAT packet during a client download, or a + server upload.""" + log.info("Handling DAT packet - block %d" % pkt.blocknumber) + log.debug("Expecting block %s", self.context.next_block) + if pkt.blocknumber == self.context.next_block: + log.debug("Good, received block %d in sequence", pkt.blocknumber) + + self.sendACK() + self.context.next_block += 1 + + log.debug("Writing %d bytes to output file", len(pkt.data)) + self.context.fileobj.write(pkt.data) + self.context.metrics.bytes += len(pkt.data) + # Check for end-of-file, any less than full data packet. + if len(pkt.data) < self.context.getBlocksize(): + log.info("End of file detected") + return None + + elif pkt.blocknumber < self.context.next_block: + if pkt.blocknumber == 0: + log.warn("There is no block zero!") + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException("There is no block zero!") + log.warn("Dropping duplicate block %d" % pkt.blocknumber) + self.context.metrics.add_dup(pkt) + log.debug("ACKing block %d again, just in case", pkt.blocknumber) + self.sendACK(pkt.blocknumber) + + else: + # FIXME: should we be more tolerant and just discard instead? + msg = "Whoa! Received future block %d but expected %d" \ + % (pkt.blocknumber, self.context.next_block) + log.error(msg) + raise TftpException(msg) + + # Default is to ack + return TftpStateExpectDAT(self.context) + +class TftpServerState(TftpState): + """The base class for server states.""" + + def __init__(self, context): + TftpState.__init__(self, context) + + # This variable is used to store the absolute path to the file being + # managed. + self.full_path = None + + def serverInitial(self, pkt, raddress, rport): + """This method performs initial setup for a server context transfer, + put here to refactor code out of the TftpStateServerRecvRRQ and + TftpStateServerRecvWRQ classes, since their initial setup is + identical. The method returns a boolean, sendoack, to indicate whether + it is required to send an OACK to the client.""" + options = pkt.options + sendoack = False + if not self.context.tidport: + self.context.tidport = rport + log.info("Setting tidport to %s" % rport) + + log.debug("Setting default options, blksize") + self.context.options = { 'blksize': DEF_BLKSIZE } + + if options: + log.debug("Options requested: %s", options) + supported_options = self.returnSupportedOptions(options) + self.context.options.update(supported_options) + sendoack = True + + # FIXME - only octet mode is supported at this time. + if pkt.mode != 'octet': + #self.sendError(TftpErrors.IllegalTftpOp) + #raise TftpException("Only octet transfers are supported at this time.") + log.warning("Received non-octet mode request. I'll reply with binary data.") + + # test host/port of client end + if self.context.host != raddress or self.context.port != rport: + self.sendError(TftpErrors.UnknownTID) + log.error("Expected traffic from %s:%s but received it " + "from %s:%s instead." + % (self.context.host, + self.context.port, + raddress, + rport)) + # FIXME: increment an error count? + # Return same state, we're still waiting for valid traffic. + return self + + log.debug("Requested filename is %s", pkt.filename) + + # Build the filename on this server and ensure it is contained + # in the specified root directory. + # + # Filenames that begin with server root are accepted. It's + # assumed the client and server are tightly connected and this + # provides backwards compatibility. + # + # Filenames otherwise are relative to the server root. If they + # begin with a '/' strip it off as otherwise os.path.join will + # treat it as absolute (regardless of whether it is ntpath or + # posixpath module + if pkt.filename.startswith(self.context.root.encode()): + full_path = pkt.filename + else: + full_path = os.path.join(self.context.root, pkt.filename.decode().lstrip('/')) + + # Use abspath to eliminate any remaining relative elements + # (e.g. '..') and ensure that is still within the server's + # root directory + self.full_path = os.path.abspath(full_path) + log.debug("full_path is %s", full_path) + if self.full_path.startswith(self.context.root): + log.info("requested file is in the server root - good") + else: + log.warn("requested file is not within the server root - bad") + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException("bad file path") + + self.context.file_to_transfer = pkt.filename + + return sendoack + + +class TftpStateServerRecvRRQ(TftpServerState): + """This class represents the state of the TFTP server when it has just + received an RRQ packet.""" + def handle(self, pkt, raddress, rport): + "Handle an initial RRQ packet as a server." + log.debug("In TftpStateServerRecvRRQ.handle") + sendoack = self.serverInitial(pkt, raddress, rport) + path = self.full_path + log.info("Opening file %s for reading" % path) + if os.path.exists(path): + # Note: Open in binary mode for win32 portability, since win32 + # blows. + self.context.fileobj = open(path, "rb") + elif self.context.dyn_file_func: + log.debug("No such file %s but using dyn_file_func", path) + self.context.fileobj = \ + self.context.dyn_file_func(self.context.file_to_transfer, raddress=raddress, rport=rport) + + if self.context.fileobj is None: + log.debug("dyn_file_func returned 'None', treating as " + "FileNotFound") + self.sendError(TftpErrors.FileNotFound) + raise TftpException("File not found: %s" % path) + else: + self.sendError(TftpErrors.FileNotFound) + raise TftpException("File not found: %s" % path) + + # Options negotiation. + if sendoack and self.context.options.has_key('tsize'): + # getting the file size for the tsize option. As we handle + # file-like objects and not only real files, we use this seeking + # method instead of asking the OS + self.context.fileobj.seek(0, os.SEEK_END) + tsize = str(self.context.fileobj.tell()) + self.context.fileobj.seek(0, 0) + self.context.options['tsize'] = tsize + + if sendoack: + # Note, next_block is 0 here since that's the proper + # acknowledgement to an OACK. + # FIXME: perhaps we do need a TftpStateExpectOACK class... + self.sendOACK() + # Note, self.context.next_block is already 0. + else: + self.context.next_block = 1 + log.debug("No requested options, starting send...") + self.context.pending_complete = self.sendDAT() + # Note, we expect an ack regardless of whether we sent a DAT or an + # OACK. + return TftpStateExpectACK(self.context) + + # Note, we don't have to check any other states in this method, that's + # up to the caller. + +class TftpStateServerRecvWRQ(TftpServerState): + """This class represents the state of the TFTP server when it has just + received a WRQ packet.""" + def make_subdirs(self): + """The purpose of this method is to, if necessary, create all of the + subdirectories leading up to the file to the written.""" + # Pull off everything below the root. + subpath = self.full_path[len(self.context.root):] + log.debug("make_subdirs: subpath is %s", subpath) + # Split on directory separators, but drop the last one, as it should + # be the filename. + dirs = subpath.split(os.sep)[:-1] + log.debug("dirs is %s", dirs) + current = self.context.root + for dir in dirs: + if dir: + current = os.path.join(current, dir) + if os.path.isdir(current): + log.debug("%s is already an existing directory", current) + else: + os.mkdir(current, 0o700) + + def handle(self, pkt, raddress, rport): + "Handle an initial WRQ packet as a server." + log.debug("In TftpStateServerRecvWRQ.handle") + sendoack = self.serverInitial(pkt, raddress, rport) + path = self.full_path + if self.context.upload_open: + f = self.context.upload_open(path, self.context) + if f is None: + self.sendError(TftpErrors.AccessViolation) + raise TftpException, "Dynamic path %s not permitted" % path + else: + self.context.fileobj = f + else: + log.info("Opening file %s for writing" % path) + if os.path.exists(path): + # FIXME: correct behavior? + log.warn("File %s exists already, overwriting..." % ( + self.context.file_to_transfer)) + # FIXME: I think we should upload to a temp file and not overwrite + # the existing file until the file is successfully uploaded. + self.make_subdirs() + self.context.fileobj = open(path, "wb") + + # Options negotiation. + if sendoack: + log.debug("Sending OACK to client") + self.sendOACK() + else: + log.debug("No requested options, expecting transfer to begin...") + self.sendACK() + # Whether we're sending an oack or not, we're expecting a DAT for + # block 1 + self.context.next_block = 1 + # We may have sent an OACK, but we're expecting a DAT as the response + # to either the OACK or an ACK, so lets unconditionally use the + # TftpStateExpectDAT state. + return TftpStateExpectDAT(self.context) + + # Note, we don't have to check any other states in this method, that's + # up to the caller. + +class TftpStateServerStart(TftpState): + """The start state for the server. This is a transitory state since at + this point we don't know if we're handling an upload or a download. We + will commit to one of them once we interpret the initial packet.""" + def handle(self, pkt, raddress, rport): + """Handle a packet we just received.""" + log.debug("In TftpStateServerStart.handle") + if isinstance(pkt, TftpPacketRRQ): + log.debug("Handling an RRQ packet") + return TftpStateServerRecvRRQ(self.context).handle(pkt, + raddress, + rport) + elif isinstance(pkt, TftpPacketWRQ): + log.debug("Handling a WRQ packet") + return TftpStateServerRecvWRQ(self.context).handle(pkt, + raddress, + rport) + else: + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException("Invalid packet to begin up/download: %s" % pkt) + +class TftpStateExpectACK(TftpState): + """This class represents the state of the transfer when a DAT was just + sent, and we are waiting for an ACK from the server. This class is the + same one used by the client during the upload, and the server during the + download.""" + def handle(self, pkt, raddress, rport): + "Handle a packet, hopefully an ACK since we just sent a DAT." + if isinstance(pkt, TftpPacketACK): + log.debug("Received ACK for packet %d" % pkt.blocknumber) + # Is this an ack to the one we just sent? + if self.context.next_block == pkt.blocknumber: + if self.context.pending_complete: + log.info("Received ACK to final DAT, we're done.") + return None + else: + log.debug("Good ACK, sending next DAT") + self.context.next_block += 1 + log.debug("Incremented next_block to %d", + self.context.next_block) + self.context.pending_complete = self.sendDAT() + + elif pkt.blocknumber < self.context.next_block: + log.warn("Received duplicate ACK for block %d" + % pkt.blocknumber) + self.context.metrics.add_dup(pkt) + + else: + log.warn("Oooh, time warp. Received ACK to packet we " + "didn't send yet. Discarding.") + self.context.metrics.errors += 1 + return self + elif isinstance(pkt, TftpPacketERR): + log.error("Received ERR packet from peer: %s" % str(pkt)) + raise TftpException("Received ERR packet from peer: %s" % str(pkt)) + else: + log.warn("Discarding unsupported packet: %s" % str(pkt)) + return self + +class TftpStateExpectDAT(TftpState): + """Just sent an ACK packet. Waiting for DAT.""" + def handle(self, pkt, raddress, rport): + """Handle the packet in response to an ACK, which should be a DAT.""" + if isinstance(pkt, TftpPacketDAT): + return self.handleDat(pkt) + + # Every other packet type is a problem. + elif isinstance(pkt, TftpPacketACK): + # Umm, we ACK, you don't. + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException("Received ACK from peer when expecting DAT") + + elif isinstance(pkt, TftpPacketWRQ): + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException("Received WRQ from peer when expecting DAT") + + elif isinstance(pkt, TftpPacketERR): + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException("Received ERR from peer: " + str(pkt)) + + else: + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException("Received unknown packet type from peer: " + str(pkt)) + +class TftpStateSentWRQ(TftpState): + """Just sent an WRQ packet for an upload.""" + def handle(self, pkt, raddress, rport): + """Handle a packet we just received.""" + if not self.context.tidport: + self.context.tidport = rport + log.debug("Set remote port for session to %s", rport) + + # If we're going to successfully transfer the file, then we should see + # either an OACK for accepted options, or an ACK to ignore options. + if isinstance(pkt, TftpPacketOACK): + log.info("Received OACK from server") + try: + self.handleOACK(pkt) + except TftpException: + log.error("Failed to negotiate options") + self.sendError(TftpErrors.FailedNegotiation) + raise + else: + log.debug("Sending first DAT packet") + self.context.pending_complete = self.sendDAT() + log.debug("Changing state to TftpStateExpectACK") + return TftpStateExpectACK(self.context) + + elif isinstance(pkt, TftpPacketACK): + log.info("Received ACK from server") + log.debug("Apparently the server ignored our options") + # The block number should be zero. + if pkt.blocknumber == 0: + log.debug("Ack blocknumber is zero as expected") + log.debug("Sending first DAT packet") + self.context.pending_complete = self.sendDAT() + log.debug("Changing state to TftpStateExpectACK") + return TftpStateExpectACK(self.context) + else: + log.warn("Discarding ACK to block %s" % pkt.blocknumber) + log.debug("Still waiting for valid response from server") + return self + + elif isinstance(pkt, TftpPacketERR): + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException("Received ERR from server: %s" % pkt) + + elif isinstance(pkt, TftpPacketRRQ): + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException("Received RRQ from server while in upload") + + elif isinstance(pkt, TftpPacketDAT): + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException("Received DAT from server while in upload") + + else: + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException("Received unknown packet type from server: %s" % pkt) + + # By default, no state change. + return self + +class TftpStateSentRRQ(TftpState): + """Just sent an RRQ packet.""" + def handle(self, pkt, raddress, rport): + """Handle the packet in response to an RRQ to the server.""" + if not self.context.tidport: + self.context.tidport = rport + log.info("Set remote port for session to %s" % rport) + + # Now check the packet type and dispatch it properly. + if isinstance(pkt, TftpPacketOACK): + log.info("Received OACK from server") + try: + self.handleOACK(pkt) + except TftpException as err: + log.error("Failed to negotiate options: %s" % str(err)) + self.sendError(TftpErrors.FailedNegotiation) + raise + else: + log.debug("Sending ACK to OACK") + + self.sendACK(blocknumber=0) + + log.debug("Changing state to TftpStateExpectDAT") + return TftpStateExpectDAT(self.context) + + elif isinstance(pkt, TftpPacketDAT): + # If there are any options set, then the server didn't honour any + # of them. + log.info("Received DAT from server") + if self.context.options: + log.info("Server ignored options, falling back to defaults") + self.context.options = { 'blksize': DEF_BLKSIZE } + return self.handleDat(pkt) + + # Every other packet type is a problem. + elif isinstance(pkt, TftpPacketACK): + # Umm, we ACK, the server doesn't. + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException("Received ACK from server while in download") + + elif isinstance(pkt, TftpPacketWRQ): + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException("Received WRQ from server while in download") + + elif isinstance(pkt, TftpPacketERR): + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException("Received ERR from server: %s" % pkt) + + else: + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException("Received unknown packet type from server: %s" % pkt) + + # By default, no state change. + return self |