diff options
Diffstat (limited to 'tester/rt/tftpy/TftpContexts.py')
-rw-r--r-- | tester/rt/tftpy/TftpContexts.py | 406 |
1 files changed, 406 insertions, 0 deletions
diff --git a/tester/rt/tftpy/TftpContexts.py b/tester/rt/tftpy/TftpContexts.py new file mode 100644 index 0000000..271441b --- /dev/null +++ b/tester/rt/tftpy/TftpContexts.py @@ -0,0 +1,406 @@ +"""This module implements all contexts for state handling during uploads and +downloads, the main interface to which being the TftpContext base class. + +The concept is simple. Each context object represents a single upload or +download, and the state object in the context object represents the current +state of that transfer. The state object has a handle() method that expects +the next packet in the transfer, and returns a state object until the transfer +is complete, at which point it returns None. That is, unless there is a fatal +error, in which case a TftpException is returned instead.""" + +from __future__ import absolute_import, division, print_function, unicode_literals +from .TftpShared import * +from .TftpPacketTypes import * +from .TftpPacketFactory import TftpPacketFactory +from .TftpStates import * +import socket, time, sys + +############################################################################### +# Utility classes +############################################################################### + +class TftpMetrics(object): + """A class representing metrics of the transfer.""" + def __init__(self): + # Bytes transferred + self.bytes = 0 + # Bytes re-sent + self.resent_bytes = 0 + # Duplicate packets received + self.dups = {} + self.dupcount = 0 + # Times + self.start_time = 0 + self.end_time = 0 + self.duration = 0 + # Rates + self.bps = 0 + self.kbps = 0 + # Generic errors + self.errors = 0 + + def compute(self): + # Compute transfer time + self.duration = self.end_time - self.start_time + if self.duration == 0: + self.duration = 1 + log.debug("TftpMetrics.compute: duration is %s", self.duration) + self.bps = (self.bytes * 8.0) / self.duration + self.kbps = self.bps / 1024.0 + log.debug("TftpMetrics.compute: kbps is %s", self.kbps) + for key in self.dups: + self.dupcount += self.dups[key] + + def add_dup(self, pkt): + """This method adds a dup for a packet to the metrics.""" + log.debug("Recording a dup of %s", pkt) + s = str(pkt) + if s in self.dups: + self.dups[s] += 1 + else: + self.dups[s] = 1 + tftpassert(self.dups[s] < MAX_DUPS, "Max duplicates reached") + +############################################################################### +# Context classes +############################################################################### + +class TftpContext(object): + """The base class of the contexts.""" + + def __init__(self, host, port, timeout, localip = ""): + """Constructor for the base context, setting shared instance + variables.""" + self.file_to_transfer = None + self.fileobj = None + self.options = None + self.packethook = None + self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + if localip != "": + self.sock.bind((localip, 0)) + self.sock.settimeout(timeout) + self.timeout = timeout + self.state = None + self.next_block = 0 + self.factory = TftpPacketFactory() + # Note, setting the host will also set self.address, as it's a property. + self.host = host + self.port = port + # The port associated with the TID + self.tidport = None + # Metrics + self.metrics = TftpMetrics() + # Fluag when the transfer is pending completion. + self.pending_complete = False + # Time when this context last received any traffic. + # FIXME: does this belong in metrics? + self.last_update = 0 + # The last packet we sent, if applicable, to make resending easy. + self.last_pkt = None + # Count the number of retry attempts. + self.retry_count = 0 + + def getBlocksize(self): + """Fetch the current blocksize for this session.""" + return int(self.options.get('blksize', 512)) + + def __del__(self): + """Simple destructor to try to call housekeeping in the end method if + not called explicitely. Leaking file descriptors is not a good + thing.""" + self.end() + + def checkTimeout(self, now): + """Compare current time with last_update time, and raise an exception + if we're over the timeout time.""" + log.debug("checking for timeout on session %s", self) + if now - self.last_update > self.timeout: + raise TftpTimeout("Timeout waiting for traffic") + + def start(self): + raise NotImplementedError("Abstract method") + + def end(self): + """Perform session cleanup, since the end method should always be + called explicitely by the calling code, this works better than the + destructor.""" + log.debug("in TftpContext.end") + self.sock.close() + if self.fileobj is not None and not self.fileobj.closed: + log.debug("self.fileobj is open - closing") + self.fileobj.close() + + def gethost(self): + "Simple getter method for use in a property." + return self.__host + + def sethost(self, host): + """Setter method that also sets the address property as a result + of the host that is set.""" + self.__host = host + self.address = socket.gethostbyname(host) + + host = property(gethost, sethost) + + def setNextBlock(self, block): + if block >= 2 ** 16: + log.debug("Block number rollover to 0 again") + block = 0 + self.__eblock = block + + def getNextBlock(self): + return self.__eblock + + next_block = property(getNextBlock, setNextBlock) + + def cycle(self): + """Here we wait for a response from the server after sending it + something, and dispatch appropriate action to that response.""" + try: + (buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE) + except socket.timeout: + log.warn("Timeout waiting for traffic, retrying...") + raise TftpTimeout("Timed-out waiting for traffic") + + # Ok, we've received a packet. Log it. + log.debug("Received %d bytes from %s:%s", + len(buffer), raddress, rport) + # And update our last updated time. + self.last_update = time.time() + + # Decode it. + recvpkt = self.factory.parse(buffer) + + # Check for known "connection". + if raddress != self.address: + log.warn("Received traffic from %s, expected host %s. Discarding" + % (raddress, self.host)) + + if self.tidport and self.tidport != rport: + log.warn("Received traffic from %s:%s but we're " + "connected to %s:%s. Discarding." + % (raddress, rport, + self.host, self.tidport)) + + # If there is a packethook defined, call it. We unconditionally + # pass all packets, it's up to the client to screen out different + # kinds of packets. This way, the client is privy to things like + # negotiated options. + if self.packethook: + self.packethook(recvpkt) + + # And handle it, possibly changing state. + self.state = self.state.handle(recvpkt, raddress, rport) + # If we didn't throw any exceptions here, reset the retry_count to + # zero. + self.retry_count = 0 + +class TftpContextServer(TftpContext): + """The context for the server.""" + def __init__(self, + host, + port, + timeout, + root, + dyn_file_func=None, + upload_open=None): + TftpContext.__init__(self, + host, + port, + timeout, + ) + # At this point we have no idea if this is a download or an upload. We + # need to let the start state determine that. + self.state = TftpStateServerStart(self) + + self.root = root + self.dyn_file_func = dyn_file_func + self.upload_open = upload_open + + def __str__(self): + return "%s:%s %s" % (self.host, self.port, self.state) + + def start(self, buffer): + """Start the state cycle. Note that the server context receives an + initial packet in its start method. Also note that the server does not + loop on cycle(), as it expects the TftpServer object to manage + that.""" + log.debug("In TftpContextServer.start") + self.metrics.start_time = time.time() + log.debug("Set metrics.start_time to %s", self.metrics.start_time) + # And update our last updated time. + self.last_update = time.time() + + pkt = self.factory.parse(buffer) + log.debug("TftpContextServer.start() - factory returned a %s", pkt) + + # Call handle once with the initial packet. This should put us into + # the download or the upload state. + self.state = self.state.handle(pkt, + self.host, + self.port) + + def end(self): + """Finish up the context.""" + TftpContext.end(self) + self.metrics.end_time = time.time() + log.debug("Set metrics.end_time to %s", self.metrics.end_time) + self.metrics.compute() + +class TftpContextClientUpload(TftpContext): + """The upload context for the client during an upload. + Note: If input is a hyphen, then we will use stdin.""" + def __init__(self, + host, + port, + filename, + input, + options, + packethook, + timeout, + localip = ""): + TftpContext.__init__(self, + host, + port, + timeout, + localip) + self.file_to_transfer = filename + self.options = options + self.packethook = packethook + # If the input object has a read() function, + # assume it is file-like. + if hasattr(input, 'read'): + self.fileobj = input + elif input == '-': + self.fileobj = sys.stdin + else: + self.fileobj = open(input, "rb") + + log.debug("TftpContextClientUpload.__init__()") + log.debug("file_to_transfer = %s, options = %s" % + (self.file_to_transfer, self.options)) + + def __str__(self): + return "%s:%s %s" % (self.host, self.port, self.state) + + def start(self): + log.info("Sending tftp upload request to %s" % self.host) + log.info(" filename -> %s" % self.file_to_transfer) + log.info(" options -> %s" % self.options) + + self.metrics.start_time = time.time() + log.debug("Set metrics.start_time to %s" % self.metrics.start_time) + + # FIXME: put this in a sendWRQ method? + pkt = TftpPacketWRQ() + pkt.filename = self.file_to_transfer + pkt.mode = "octet" # FIXME - shouldn't hardcode this + pkt.options = self.options + self.sock.sendto(pkt.encode().buffer, (self.host, self.port)) + self.next_block = 1 + self.last_pkt = pkt + # FIXME: should we centralize sendto operations so we can refactor all + # saving of the packet to the last_pkt field? + + self.state = TftpStateSentWRQ(self) + + while self.state: + try: + log.debug("State is %s" % self.state) + self.cycle() + except TftpTimeout as err: + log.error(str(err)) + self.retry_count += 1 + if self.retry_count >= TIMEOUT_RETRIES: + log.debug("hit max retries, giving up") + raise + else: + log.warn("resending last packet") + self.state.resendLast() + + def end(self): + """Finish up the context.""" + TftpContext.end(self) + self.metrics.end_time = time.time() + log.debug("Set metrics.end_time to %s" % self.metrics.end_time) + self.metrics.compute() + + +class TftpContextClientDownload(TftpContext): + """The download context for the client during a download. + Note: If output is a hyphen, then the output will be sent to stdout.""" + def __init__(self, + host, + port, + filename, + output, + options, + packethook, + timeout, + localip = ""): + TftpContext.__init__(self, + host, + port, + timeout, + localip) + # FIXME: should we refactor setting of these params? + self.file_to_transfer = filename + self.options = options + self.packethook = packethook + # If the output object has a write() function, + # assume it is file-like. + if hasattr(output, 'write'): + self.fileobj = output + # If the output filename is -, then use stdout + elif output == '-': + self.fileobj = sys.stdout + else: + self.fileobj = open(output, "wb") + + log.debug("TftpContextClientDownload.__init__()") + log.debug("file_to_transfer = %s, options = %s" % + (self.file_to_transfer, self.options)) + + def __str__(self): + return "%s:%s %s" % (self.host, self.port, self.state) + + def start(self): + """Initiate the download.""" + log.info("Sending tftp download request to %s" % self.host) + log.info(" filename -> %s" % self.file_to_transfer) + log.info(" options -> %s" % self.options) + + self.metrics.start_time = time.time() + log.debug("Set metrics.start_time to %s" % self.metrics.start_time) + + # FIXME: put this in a sendRRQ method? + pkt = TftpPacketRRQ() + pkt.filename = self.file_to_transfer + pkt.mode = "octet" # FIXME - shouldn't hardcode this + pkt.options = self.options + self.sock.sendto(pkt.encode().buffer, (self.host, self.port)) + self.next_block = 1 + self.last_pkt = pkt + + self.state = TftpStateSentRRQ(self) + + while self.state: + try: + log.debug("State is %s" % self.state) + self.cycle() + except TftpTimeout as err: + log.error(str(err)) + self.retry_count += 1 + if self.retry_count >= TIMEOUT_RETRIES: + log.debug("hit max retries, giving up") + raise + else: + log.warn("resending last packet") + self.state.resendLast() + + def end(self): + """Finish up the context.""" + TftpContext.end(self) + self.metrics.end_time = time.time() + log.debug("Set metrics.end_time to %s" % self.metrics.end_time) + self.metrics.compute() |