diff options
Diffstat (limited to 'tester/rt/tftpy/TftpPacketTypes.py')
-rw-r--r-- | tester/rt/tftpy/TftpPacketTypes.py | 217 |
1 files changed, 125 insertions, 92 deletions
diff --git a/tester/rt/tftpy/TftpPacketTypes.py b/tester/rt/tftpy/TftpPacketTypes.py index e45bb02..3d3bdf8 100644 --- a/tester/rt/tftpy/TftpPacketTypes.py +++ b/tester/rt/tftpy/TftpPacketTypes.py @@ -1,11 +1,16 @@ +# vim: ts=4 sw=4 et ai: +# -*- coding: utf8 -*- """This module implements the packet types of TFTP itself, and the corresponding encode and decode methods for them.""" -from __future__ import absolute_import, division, print_function, unicode_literals + import struct import sys +import logging from .TftpShared import * +log = logging.getLogger('tftpy.TftpPacketTypes') + class TftpSession(object): """This class is the base class for the tftp client and server. Any shared code should be in this class.""" @@ -20,17 +25,23 @@ class TftpPacketWithOptions(object): def __init__(self): self.options = {} + # Always use unicode strings, except at the encode/decode barrier. + # Simpler to keep things clear. def setoptions(self, options): log.debug("in TftpPacketWithOptions.setoptions") - log.debug("options: %s" % options) + log.debug("options: %s", options) myoptions = {} for key in options: - newkey = str(key) - myoptions[newkey] = str(options[key]) - log.debug("populated myoptions with %s = %s" - % (newkey, myoptions[newkey])) - - log.debug("setting options hash to: %s" % myoptions) + newkey = key + if isinstance(key, bytes): + newkey = newkey.decode('ascii') + newval = options[key] + if isinstance(newval, bytes): + newval = newval.decode('ascii') + myoptions[newkey] = newval + log.debug("populated myoptions with %s = %s", newkey, myoptions[newkey]) + + log.debug("setting options hash to: %s", myoptions) self._options = myoptions def getoptions(self): @@ -46,11 +57,11 @@ class TftpPacketWithOptions(object): """This method decodes the section of the buffer that contains an unknown number of options. It returns a dictionary of option names and values.""" - format = "!" + fmt = b"!" options = {} - log.debug("decode_options: buffer is: %s" % repr(buffer)) - log.debug("size of buffer is %d bytes" % len(buffer)) + log.debug("decode_options: buffer is: %s", repr(buffer)) + log.debug("size of buffer is %d bytes", len(buffer)) if len(buffer) == 0: log.debug("size of buffer is zero, returning empty hash") return {} @@ -58,25 +69,28 @@ class TftpPacketWithOptions(object): # Count the nulls in the buffer. Each one terminates a string. log.debug("about to iterate options buffer counting nulls") length = 0 - for c in buffer: - if ord(c) == 0: - log.debug("found a null at length %d" % length) + for i in range(len(buffer)): + if ord(buffer[i:i+1]) == 0: + log.debug("found a null at length %d", length) if length > 0: - format += "%dsx" % length + fmt += b"%dsx" % length length = -1 else: raise TftpException("Invalid options in buffer") length += 1 - log.debug("about to unpack, format is: %s" % format) - mystruct = struct.unpack(format, buffer) + log.debug("about to unpack, fmt is: %s", fmt) + mystruct = struct.unpack(fmt, buffer) tftpassert(len(mystruct) % 2 == 0, "packet with odd number of option/value pairs") for i in range(0, len(mystruct), 2): - log.debug("setting option %s to %s" % (mystruct[i], mystruct[i+1])) - options[mystruct[i]] = mystruct[i+1] + key = mystruct[i].decode('ascii') + val = mystruct[i+1].decode('ascii') + log.debug("setting option %s to %s", key, val) + log.debug("types are %s and %s", type(key), type(val)) + options[key] = val return options @@ -120,46 +134,59 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions): tftpassert(self.filename, "filename required in initial packet") tftpassert(self.mode, "mode required in initial packet") # Make sure filename and mode are bytestrings. - self.filename = self.filename.encode('ascii') - self.mode = self.mode.encode('ascii') + filename = self.filename + mode = self.mode + if not isinstance(filename, bytes): + filename = filename.encode('ascii') + if not isinstance(self.mode, bytes): + mode = mode.encode('ascii') ptype = None if self.opcode == 1: ptype = "RRQ" else: ptype = "WRQ" - log.debug("Encoding %s packet, filename = %s, mode = %s" - % (ptype, self.filename, self.mode)) + log.debug("Encoding %s packet, filename = %s, mode = %s", + ptype, filename, mode) for key in self.options: - log.debug(" Option %s = %s" % (key, self.options[key])) + log.debug(" Option %s = %s", key, self.options[key]) - format = b"!H" - format += b"%dsx" % len(self.filename) - if self.mode == b"octet": - format += b"5sx" + fmt = b"!H" + fmt += b"%dsx" % len(filename) + if mode == b"octet": + fmt += b"5sx" else: - raise AssertionError("Unsupported mode: %s" % self.mode) - # Add options. + raise AssertionError("Unsupported mode: %s" % mode) + # Add options. Note that the options list must be bytes. options_list = [] - if len(self.options.keys()) > 0: + if len(list(self.options.keys())) > 0: log.debug("there are options to encode") for key in self.options: # Populate the option name - format += b"%dsx" % len(key) - options_list.append(key.encode('ascii')) + name = key + if not isinstance(name, bytes): + name = name.encode('ascii') + options_list.append(name) + fmt += b"%dsx" % len(name) # Populate the option value - format += b"%dsx" % len(self.options[key].encode('ascii')) - options_list.append(self.options[key].encode('ascii')) - - log.debug("format is %s" % format) - log.debug("options_list is %s" % options_list) - log.debug("size of struct is %d" % struct.calcsize(format)) - - self.buffer = struct.pack(format, + value = self.options[key] + # Work with all strings. + if isinstance(value, int): + value = str(value) + if not isinstance(value, bytes): + value = value.encode('ascii') + options_list.append(value) + fmt += b"%dsx" % len(value) + + log.debug("fmt is %s", fmt) + log.debug("options_list is %s", options_list) + log.debug("size of struct is %d", struct.calcsize(fmt)) + + self.buffer = struct.pack(fmt, self.opcode, - self.filename, - self.mode, + filename, + mode, *options_list) - log.debug("buffer is %s" % repr(self.buffer)) + log.debug("buffer is %s", repr(self.buffer)) return self def decode(self): @@ -167,18 +194,15 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions): # FIXME - this shares a lot of code with decode_options nulls = 0 - format = "" + fmt = b"" nulls = length = tlength = 0 log.debug("in decode: about to iterate buffer counting nulls") subbuf = self.buffer[2:] - for c in subbuf: - if sys.version_info[0] <= 2: - c = ord(c) - if c == 0: + for i in range(len(subbuf)): + if ord(subbuf[i:i+1]) == 0: nulls += 1 - log.debug("found a null at length %d, now have %d" - % (length, nulls)) - format += "%dsx" % length + log.debug("found a null at length %d, now have %d", length, nulls) + fmt += b"%dsx" % length length = -1 # At 2 nulls, we want to mark that position for decoding. if nulls == 2: @@ -186,21 +210,22 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions): length += 1 tlength += 1 - log.debug("hopefully found end of mode at length %d" % tlength) + log.debug("hopefully found end of mode at length %d", tlength) # length should now be the end of the mode. tftpassert(nulls == 2, "malformed packet") shortbuf = subbuf[:tlength+1] - log.debug("about to unpack buffer with format: %s" % format) - log.debug("unpacking buffer: " + repr(shortbuf)) - mystruct = struct.unpack(format, shortbuf) + log.debug("about to unpack buffer with fmt: %s", fmt) + log.debug("unpacking buffer: %s", repr(shortbuf)) + mystruct = struct.unpack(fmt, shortbuf) tftpassert(len(mystruct) == 2, "malformed packet") - self.filename = mystruct[0] - self.mode = mystruct[1].lower() # force lc - bug 17 - log.debug("set filename to %s" % self.filename) - log.debug("set mode to %s" % self.mode) + self.filename = mystruct[0].decode('ascii') + self.mode = mystruct[1].decode('ascii').lower() # force lc - bug 17 + log.debug("set filename to %s", self.filename) + log.debug("set mode to %s", self.mode) self.options = self.decode_options(subbuf[tlength+1:]) + log.debug("options dict is now %s", self.options) return self class TftpPacketRRQ(TftpPacketInitial): @@ -269,11 +294,14 @@ class TftpPacketDAT(TftpPacket): returns self for easy method chaining.""" if len(self.data) == 0: log.debug("Encoding an empty DAT packet") - format = "!HH%ds" % len(self.data) - self.buffer = struct.pack(format, + data = self.data + if not isinstance(self.data, bytes): + data = self.data.encode('ascii') + fmt = b"!HH%ds" % len(data) + self.buffer = struct.pack(fmt, self.opcode, self.blocknumber, - self.data) + data) return self def decode(self): @@ -281,14 +309,12 @@ class TftpPacketDAT(TftpPacket): easy method chaining.""" # We know the first 2 bytes are the opcode. The second two are the # block number. - (self.blocknumber,) = struct.unpack("!H", self.buffer[2:4]) - log.debug("decoding DAT packet, block number %d" % self.blocknumber) - log.debug("should be %d bytes in the packet total" - % len(self.buffer)) + (self.blocknumber,) = struct.unpack(str("!H"), self.buffer[2:4]) + log.debug("decoding DAT packet, block number %d", self.blocknumber) + log.debug("should be %d bytes in the packet total", len(self.buffer)) # Everything else is data. self.data = self.buffer[4:] - log.debug("found %d bytes of data" - % len(self.data)) + log.debug("found %d bytes of data", len(self.data)) return self class TftpPacketACK(TftpPacket): @@ -309,9 +335,9 @@ class TftpPacketACK(TftpPacket): return 'ACK packet: block %d' % self.blocknumber def encode(self): - log.debug("encoding ACK: opcode = %d, block = %d" - % (self.opcode, self.blocknumber)) - self.buffer = struct.pack("!HH", self.opcode, self.blocknumber) + log.debug("encoding ACK: opcode = %d, block = %d", + self.opcode, self.blocknumber) + self.buffer = struct.pack(str("!HH"), self.opcode, self.blocknumber) return self def decode(self): @@ -319,9 +345,9 @@ class TftpPacketACK(TftpPacket): log.debug("detected TFTP ACK but request is too large, will truncate") log.debug("buffer was: %s", repr(self.buffer)) self.buffer = self.buffer[0:4] - self.opcode, self.blocknumber = struct.unpack("!HH", self.buffer) - log.debug("decoded ACK packet: opcode = %d, block = %d" - % (self.opcode, self.blocknumber)) + self.opcode, self.blocknumber = struct.unpack(str("!HH"), self.buffer) + log.debug("decoded ACK packet: opcode = %d, block = %d", + self.opcode, self.blocknumber) return self class TftpPacketERR(TftpPacket): @@ -373,9 +399,9 @@ class TftpPacketERR(TftpPacket): def encode(self): """Encode the DAT packet based on instance variables, populating self.buffer, returning self.""" - format = "!HH%dsx" % len(self.errmsgs[self.errorcode]) - log.debug("encoding ERR packet with format %s" % format) - self.buffer = struct.pack(format, + fmt = b"!HH%dsx" % len(self.errmsgs[self.errorcode]) + log.debug("encoding ERR packet with fmt %s", fmt) + self.buffer = struct.pack(fmt, self.opcode, self.errorcode, self.errmsgs[self.errorcode]) @@ -385,18 +411,18 @@ class TftpPacketERR(TftpPacket): "Decode self.buffer, populating instance variables and return self." buflen = len(self.buffer) tftpassert(buflen >= 4, "malformed ERR packet, too short") - log.debug("Decoding ERR packet, length %s bytes" % buflen) + log.debug("Decoding ERR packet, length %s bytes", buflen) if buflen == 4: log.debug("Allowing this affront to the RFC of a 4-byte packet") - format = "!HH" - log.debug("Decoding ERR packet with format: %s" % format) - self.opcode, self.errorcode = struct.unpack(format, + fmt = b"!HH" + log.debug("Decoding ERR packet with fmt: %s", fmt) + self.opcode, self.errorcode = struct.unpack(fmt, self.buffer) else: log.debug("Good ERR packet > 4 bytes") - format = "!HH%dsx" % (len(self.buffer) - 5) - log.debug("Decoding ERR packet with format: %s" % format) - self.opcode, self.errorcode, self.errmsg = struct.unpack(format, + fmt = b"!HH%dsx" % (len(self.buffer) - 5) + log.debug("Decoding ERR packet with fmt: %s", fmt) + self.opcode, self.errorcode, self.errmsg = struct.unpack(fmt, self.buffer) log.error("ERR packet - errorcode: %d, message: %s" % (self.errorcode, self.errmsg)) @@ -419,17 +445,24 @@ class TftpPacketOACK(TftpPacket, TftpPacketWithOptions): return 'OACK packet:\n options = %s' % self.options def encode(self): - format = "!H" # opcode + fmt = b"!H" # opcode options_list = [] log.debug("in TftpPacketOACK.encode") for key in self.options: - log.debug("looping on option key %s" % key) - log.debug("value is %s" % self.options[key]) - format += "%dsx" % len(key) - format += "%dsx" % len(self.options[key]) + value = self.options[key] + if isinstance(value, int): + value = str(value) + if not isinstance(key, bytes): + key = key.encode('ascii') + if not isinstance(value, bytes): + value = value.encode('ascii') + log.debug("looping on option key %s", key) + log.debug("value is %s", value) + fmt += b"%dsx" % len(key) + fmt += b"%dsx" % len(value) options_list.append(key) - options_list.append(self.options[key]) - self.buffer = struct.pack(format, self.opcode, *options_list) + options_list.append(value) + self.buffer = struct.pack(fmt, self.opcode, *options_list) return self def decode(self): |