"SfR Fresh" - the SfR Freeware/Shareware Archive

Member "BitTorrent-4.26.0/BitTorrent/Connector.py" of archive BitTorrent-4.26.0.tar.gz:


# The contents of this file are subject to the BitTorrent Open Source License
# Version 1.1 (the License).  You may not copy or use this file, in either
# source code or executable form, except in compliance with the License.  You
# may obtain a copy of the License at http://www.bittorrent.com/license/.
#
# Software distributed under the License is distributed on an AS IS basis,
# WITHOUT WARRANTY OF ANY KIND, either express or implied.  See the License
# for the specific language governing rights and limitations under the
# License.

# Originally written by Bram Cohen, heavily modified by Uoti Urpala
# Fast extensions added by David Harrison

from __future__ import generators

# DEBUG switches.
# Set disable_fast_extension to True if you suspect the fast extension is
# causing problems: we then neither advertise it nor use it with peers.
disable_fast_extension = False
# END DEBUG

# noisy: verbose connection logging.  log_data = True also forces noisy on.
noisy = False
log_data = False

# for crypto
from random import randrange
from BTL.hash import sha
from Crypto.Cipher import ARC4
# urandom comes from obsoletepythonsupport

import struct
from struct import pack, unpack
from cStringIO import StringIO

from BTL.bencode import bencode, bdecode
from BitTorrent.RawServer_twisted import Handler
from BTL.bitfield import Bitfield
from BTL import IPTools
from BTL.obsoletepythonsupport import *
from BitTorrent.ClientIdentifier import identify_client
import logging

def toint(s):
    """Decode a 4-byte big-endian signed integer from the string ``s``."""
    (value,) = struct.unpack("!i", s)
    return value

def tobinary(i):
    """Encode ``i`` as a 4-byte big-endian signed integer string."""
    encoded = struct.pack("!i", i)
    return encoded


class BTMessages(object):
    """Two-way lookup between wire message ids and message names.

    ``messages`` maps integer ids to names.  Ids are stored as the
    one-character strings that appear on the wire:
      chr_to_message -- wire id -> name
      message_to_chr -- name -> wire id
    """

    def __init__(self, messages):
        pairs = [(chr(code), name) for code, name in messages.items()]
        self.chr_to_message = dict(pairs)
        self.message_to_chr = dict([(name, ch) for ch, name in pairs])

    def __getitem__(self, key):
        """Return the name for wire id ``key``, or a placeholder if unknown."""
        if key in self.chr_to_message:
            return self.chr_to_message[key]
        return "UNKNOWN: %r" % key
        
# Wire message ids (the byte that starts every length-prefixed message) and
# the payload that follows the id byte.
message_dict = BTMessages({
# no payload
0:'CHOKE',
1:'UNCHOKE',
2:'INTERESTED',
3:'NOT_INTERESTED',

# index
4:'HAVE',
# bitfield
5:'BITFIELD',
# index, begin, length
6:'REQUEST',
# index, begin, piece
7:'PIECE',
# index, begin, length
8:'CANCEL',

# 2-byte port message
9:'PORT',

# metainfo exchange
10:'WANT_METAINFO',
11:'METAINFO',

# index
12:'SUSPECT_PIECE',

# index
13:'SUGGEST_PIECE', # FAST_EXTENSION
# no args
14:'HAVE_ALL', # FAST_EXTENSION
15:'HAVE_NONE', # FAST_EXTENSION

# index, begin, length
16:'REJECT_REQUEST', # FAST_EXTENSION

# index
17:'ALLOWED_FAST', # FAST_EXTENSION

# bencoded dict: 't' message type, 'p' compact address
18:'HOLE_PUNCH', # NAT_TRAVERSAL

# message id, bencoded payload
20:'UTORRENT_MSG', # UTORRENT
})

# Export every message id as a module-level constant (CHOKE, UNCHOKE, ...,
# UTORRENT_MSG).  At module scope locals() is the same mapping as globals(),
# but the locals() mapping is documented as not to be modified, so update
# globals() explicitly.
globals().update(message_dict.message_to_chr)

# When the Azureus extension is negotiated, a message whose id byte equals
# CHOKE's is parsed as an Azureus messaging-protocol message (see the
# AZUREUS_SUCKS check in _got_message); this alias names that overlap.
AZUREUS_SUCKS = CHOKE

# uTorrent extension sub-message ids: the byte after the UTORRENT_MSG id.
UTORRENT_MSG_INFO = chr(0)
UTORRENT_MSG_PEX = chr(1)
                          
# Reserved handshake bits we know about, as (reserved byte index, bit mask):
#   byte 0, 0x80 -- Azureus Messaging Protocol (recognised, not advertised)
#   byte 5, 0x10 -- uTorrent extensions: peer exchange, encrypted
#                   connections, broadcast listen port
#   byte 7, 0x01 -- DHT
#   byte 7, 0x04 -- fast extension: suggest, have all, have none,
#                   reject request and allowed fast
#   byte 7, 0x08 -- NAT traversal (hole punch)
AZUREUS = 0x80
UTORRENT = 0x10
DHT = 0x01
FAST_EXTENSION = 0x04
NAT_TRAVERSAL = 0x08

# The bits we advertise in reserved byte 7.
LAST_BYTE = DHT | NAT_TRAVERSAL
if not disable_fast_extension:
    LAST_BYTE |= FAST_EXTENSION

# The 8 reserved bytes we send.  Byte 0 is left zero: the Azureus bit is not
# advertised, so the Azureus extension is never switched on for a peer.
FLAGS = '\0' * 5 + chr(UTORRENT) + '\0' + chr(LAST_BYTE)
protocol_name = 'BitTorrent protocol'

# for crypto
# 768-bit Diffie-Hellman prime used for the obfuscated handshake (generator 2
# is used with it in send_handshake / _read_messages).  Must not change: both
# peers have to agree on it.
dh_prime = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A63A36210000000000090563
# Upper bound (exclusive) for the random padding we send.  The pad length is
# later written as '\x00' + chr(padlen), so it must stay below 256.
PAD_MAX = 200 # less than protocol maximum, and later assumed to be < 256
# Size in bytes of an encoded DH public key (see numtobyte's 192 hex digits).
DH_BYTES = 96
def bytetonum(x):
    return long(x.encode('hex'), 16)
def numtobyte(x):
    x = hex(x).lstrip('0x').rstrip('Ll')
    x = '0'*(192 - len(x)) + x
    return x.decode('hex')
  
# Raw data logging implies verbose logging.
if log_data:
    noisy = True

# In verbose mode, send connection debug output to a stream handler (stderr by
# default) and expose a module-level ``log`` shortcut.  ``log`` and
# ``stream_handler`` exist only when ``noisy`` is set, so callers must check
# ``noisy`` before using them.
if noisy:
    stream_handler = logging.StreamHandler()
    connection_logger = logging.getLogger("BitTorrent.Connector")
    connection_logger.setLevel(logging.DEBUG)
    connection_logger.addHandler(stream_handler)
    log = connection_logger.debug


# Tracker NAT checking:
# Aside: When you start up a Torrent, the first connection after contacting
# the tracker is probably a callback from the tracker to perform a NatCheck.
# (I was a bit confused about where this connection was coming from: it
# didn't have any bits set in the handshake's reserved bytes, even when
# there were no other peers. Call me stupid.)   --Dave



class Connector(Handler):
    """Implements the syntax of the BitTorrent protocol. 
       See Upload.py and Download.py for the connection-level 
       semantics."""

    def __init__(self, parent, connection, id, is_local,
                 obfuscate_outgoing=False, log_prefix = "", lan=False):
        """Set up protocol state for one peer connection.

        parent             -- provides config, infohash, my_id and the
                              torrent-level callbacks used by this class.
        connection         -- the transport; read here for ip/port, later
                              written to and closed.  Its ``handler`` is
                              set to this object as the last step.
        id                 -- the peer id if already known, else None.
        is_local           -- True if we initiated the connection; our
                              handshake is then sent immediately.
        obfuscate_outgoing -- for connections we initiate, start with the
                              Diffie-Hellman/RC4 obfuscated handshake.
        log_prefix         -- prefix of this connection's logger name.
        lan                -- stored as self.lan; not otherwise used in the
                              code shown here.
        """
        self.parent = parent
        self.connection = connection
        self.id = id
        self.ip = connection.ip
        self.port = connection.port
        self.addr = (self.ip, self.port)
        self.hostname = None
        self.locally_initiated = is_local
        if self.locally_initiated:
            # We already know the torrent; for incoming connections the limit
            # is set by set_parent once the torrent has been selected.
            self.max_message_length = self.parent.config['max_message_length']
            # We dialled this port, so it is the peer's listening port.
            self.listening_port = self.port
        else:
            # Learned later, e.g. from uTorrent/Azureus extension handshakes.
            self.listening_port = None
        self.complete = False
        self.lan = lan
        self.closed = False
        self.got_anything = False
        self.next_upload = None
        self.upload = None
        self.download = None
        self._buffer = StringIO()
        # Prime the parsing generator: its first yield is the number of bytes
        # it wants before it can start examining the handshake.
        self._reader = self._read_messages()
        self._next_len = self._reader.next()
        self._partial_message = None
        self._outqueue = StringIO()
        self._decrypt = None
        # Reset before send_handshake, which sets it when obfuscating.
        self._privkey = None
        # Connections start out choked from our side.
        self.choke_sent = True

        # Extension flags, switched on during the handshake when both sides
        # advertise support.
        self.uses_utorrent_extension = False
        self.uses_utorrent_pex = False
        self.uses_azureus_extension = False
        self.uses_azureus_pex = False
        self.uses_dht = False
        self.uses_fast_extension = False
        self.uses_nat_traversal = False

        self.obfuscate_outgoing = obfuscate_outgoing
        self.dht_port = None
        # Peer-exchange sets: what we last sent vs. what the peer reported.
        self.local_pex_set = set()
        self.remote_pex_set = set()
        self.sloppy_pre_connection_counter = 0
        self._sent_listeners = set()
        self.received_data = False
        self.log_prefix = log_prefix
        # The logger name is refined in _read_messages once the infohash and
        # peer id are known.
        if self.locally_initiated:
            self.logger = logging.getLogger(
                self.log_prefix + '.' + repr(self.parent.infohash) +
                '.peer_id_not_yet')
        else:
            self.logger = logging.getLogger(
                self.log_prefix + '.infohash_not_yet.peer_id_not_yet')
        self.logger.setLevel(logging.DEBUG)

        if noisy:
            self.logger.addHandler(stream_handler)

        # Incoming connections wait for the peer to speak first.
        if self.locally_initiated:
            self.send_handshake()
        # Greg's comments: ow ow ow
        # Register as the connection's handler only after setup is complete.
        self.connection.handler = self

    def protocol_violation(self, s):
        """Log that the peer broke the protocol.

        ``s`` describes the violation.  The peer address is appended, plus the
        client identified from the peer id when one is known.  This only logs;
        closing the connection is up to the caller.
        """
        parts = ["%s %s" % (s, self.addr)]
        if self.id:
            parts.append("%r" % (identify_client(self.id),))
        message = " ".join(parts)
        if noisy:
            log("FAUX PAS: %s" % message)
        self.logger.info(message)

    def send_handshake(self):
        """Send our opening handshake.

        Obfuscated: send a fresh DH public key followed by random padding,
        keeping the private key in self._privkey for _read_messages.
        Plain: send the protocol header, reserved FLAGS and the infohash,
        plus our peer id if the peer's id is already known (otherwise ours is
        sent after theirs arrives).
        """
        if not self.obfuscate_outgoing:
            if noisy:
                log("sending reserved: %s" %
                    ' '.join([ch.encode('hex') for ch in FLAGS]))
            header = ''.join((chr(len(protocol_name)), protocol_name,
                              FLAGS, self.parent.infohash))
            self.connection.write(header)
            if self.id is not None:
                self.connection.write(self.parent.my_id)
            return

        secret = bytetonum(urandom(20))
        self._privkey = secret
        keybytes = numtobyte(pow(2, secret, dh_prime))
        padding = urandom(randrange(PAD_MAX))
        self.connection.write(keybytes + padding)

    def set_parent(self, parent):
        self.parent = parent
        self.max_message_length = self.parent.config['max_message_length']

    def close(self):
        """Close the underlying connection unless self.closed is already set.

        Also drops the peer address from the parent's address cache.
        """
        if noisy:
            log("CLOSE")
        if self.closed:
            return
        self.parent.remove_addr_from_cache(self.addr)
        self.connection.close()

    def send_interested(self):
        """Send an INTERESTED message."""
        msg = INTERESTED
        if noisy:
            log("SEND %s" % message_dict[msg])
        self._send_message(msg)

    def send_not_interested(self):
        """Send a NOT_INTERESTED message."""
        msg = NOT_INTERESTED
        if noisy:
            log("SEND %s" % message_dict[msg])
        self._send_message(msg)

    def send_choke(self):
        """Send CHOKE now, unless a partially written PIECE is in flight.

        While a piece is partially sent nothing is written here; send_partial
        later reconciles self.choke_sent with self.upload.choked.
        """
        if self._partial_message is not None:
            return
        if noisy:
            log("SEND %s" % message_dict[CHOKE])
        self._send_message(CHOKE)
        self.choke_sent = True
        self.upload.sent_choke()

    def send_unchoke(self):
        """Send UNCHOKE now, unless a partially written PIECE is in flight.

        While a piece is partially sent nothing is written here; send_partial
        later reconciles self.choke_sent with self.upload.choked.
        """
        if self._partial_message is not None:
            return
        if noisy:
            log("SEND %s" % message_dict[UNCHOKE])
        self._send_message(UNCHOKE)
        self.choke_sent = False

    def send_port(self, port):
        """Send a PORT message with ``port`` as a 2-byte big-endian value."""
        if noisy:
            log("SEND %s" % message_dict[PORT])
        payload = pack('!H', port)
        self._send_message(PORT, payload)
        
    def send_request(self, index, begin, length):
        """Request ``length`` bytes at offset ``begin`` of piece ``index``."""
        if noisy:
            log("SEND %s %d %d %d" % (message_dict[REQUEST], index, begin, length))
        body = pack("!ciii", REQUEST, index, begin, length)
        self._send_message(body)

    def send_cancel(self, index, begin, length):
        """Cancel an earlier REQUEST for ``length`` bytes at ``begin`` of piece ``index``.

        Logs in noisy mode, like send_request, so request/cancel pairs can be
        matched up when debugging.
        """
        if noisy:
            log("SEND %s %d %d %d" % (message_dict[CANCEL], index, begin, length))
        self._send_message(pack("!ciii", CANCEL, index, begin, length))

    def send_bitfield(self, bitfield):
        """Send a BITFIELD message whose payload is ``bitfield``."""
        msg = BITFIELD
        if noisy:
            log("SEND %s" % message_dict[msg])
        self._send_message(msg, bitfield)

    def send_have(self, index):
        """Send a HAVE message for piece ``index``."""
        if noisy:
            log("SEND %s" % message_dict[HAVE])
        body = pack("!ci", HAVE, index)
        self._send_message(body)

    def send_have_all(self):
        """Send HAVE_ALL; only valid once the fast extension is in use."""
        assert self.uses_fast_extension
        kind = HAVE_ALL
        if noisy:
            log("SEND %s" % message_dict[kind])
        self._send_message(pack("!c", kind))

    def send_have_none(self):
        """Send HAVE_NONE; only valid once the fast extension is in use."""
        assert self.uses_fast_extension
        kind = HAVE_NONE
        if noisy:
            log("SEND %s" % message_dict[kind])
        self._send_message(pack("!c", kind))

    def send_reject_request(self, index, begin, length):
        """Reject the peer's REQUEST for (index, begin, length).

        Only valid once the fast extension is in use.  Logs in noisy mode,
        consistent with the other senders.
        """
        assert self.uses_fast_extension
        if noisy:
            log("SEND %s %d %d %d" %
                (message_dict[REJECT_REQUEST], index, begin, length))
        self._send_message(pack("!ciii", REJECT_REQUEST, index, begin, length))

    def send_allowed_fast(self, index):
        """Send ALLOWED_FAST for piece ``index``.

        Only valid once the fast extension is in use.  Logs in noisy mode,
        consistent with the other senders.
        """
        assert self.uses_fast_extension
        if noisy:
            log("SEND %s %d" % (message_dict[ALLOWED_FAST], index))
        self._send_message(pack("!ci", ALLOWED_FAST, index))

    def send_keepalive(self):
        """Send a keep-alive: a message with an empty body."""
        empty_body = ''
        self._send_message(empty_body)

    def send_holepunch_request(self, addr):
        # disabled, for now.
        return
    
        if not self.uses_nat_traversal:
            # maybe close?
            return
        d = {'t': 'r'}
        d['p'] = IPTools.compact(*addr)
        self._send_message(HOLE_PUNCH, d)

    def send_pex(self, pex_set):
        """Send a uTorrent PEX update describing changes since the last call.

        Does nothing unless the peer negotiated both the uTorrent extension
        and ut_pex.  Only the addresses added to / dropped from ``pex_set``,
        relative to the set sent last time, are transmitted.  Nothing is sent
        when there is no difference.
        """
        if not (self.uses_utorrent_extension and self.uses_utorrent_pex):
            return
        added = pex_set.difference(self.local_pex_set)
        dropped = self.local_pex_set.difference(pex_set)
        # Keep a private copy.  If the caller later mutates the same set
        # object in place, diffing against an alias would hide every change.
        self.local_pex_set = set(pex_set)
        if not (added or dropped):
            return
        d = {}
        d['added'] = IPTools.compact_sequence(added)
        # One zero flag byte per added peer.
        d['added.f'] = chr(0) * len(added)
        d['dropped'] = IPTools.compact_sequence(dropped)
        self._send_message(UTORRENT_MSG, UTORRENT_MSG_PEX, bencode(d))

    def add_sent_listener(self, listener):
        """Passed a function/functor that accepts a single byte argument,
           which is called everytime this uploader sends a chunk."""
        self._sent_listeners.add(listener)

    def remove_sent_listener(self, listener):
        self._sent_listeners.remove(listener)

    def fire_sent_listeners(self, bytes):
        for f in self._sent_listeners:
           f(bytes)

    def send_partial(self, bytes):
        """Write up to ``bytes`` bytes of queued PIECE data to the connection.

        If no PIECE data is in progress, blocks are popped from
        self.upload.buffer and framed into one buffer, stopping once it holds
        at least ``bytes`` bytes.  If that data is longer than ``bytes``, only
        the first ``bytes`` bytes are written and the rest is kept for the
        next call.  Otherwise the remainder is flushed together with any
        pending choke-state change and the queued control messages in
        self._outqueue, so the return value can exceed ``bytes``.

        Returns the number of bytes written, or 0 if the connection is closed
        or there is nothing to send.
        """
        if self.closed:
            return 0
        if self._partial_message is None and not self.upload.buffer:
            return 0
        if self._partial_message is None:
            buf = StringIO()
            while self.upload.buffer and buf.tell() < bytes:
                t, piece = self.upload.buffer.pop(0)
                index, begin, length = t
                # Header only: length prefix (id + index + begin + data =
                # len(piece) + 9), PIECE id, index, begin.  The data is
                # written separately below.  (The format used to carry a
                # trailing repeat count with no type code; Python 2's struct
                # silently ignored it, newer versions reject it.)
                msg = pack("!icii", len(piece) + 9, PIECE, index, begin)
                buf.write(msg)
                buf.write(piece)
                if noisy: log("SEND PIECE %d %d" % (index, begin))
            self._partial_message = buf.getvalue()
        if bytes < len(self._partial_message):
            self.fire_sent_listeners(bytes)
            self.connection.write(buffer(self._partial_message, 0, bytes))
            self._partial_message = buffer(self._partial_message, bytes)
            return bytes
        buf = StringIO()
        buf.write(self._partial_message)
        self._partial_message = None
        # A choke/unchoke requested while a piece was mid-flight was deferred;
        # emit it now that the piece is complete.
        if self.choke_sent != self.upload.choked:
            if self.upload.choked:
                self._outqueue.write(pack("!ic", 1, CHOKE))
                self.upload.sent_choke()
            else:
                self._outqueue.write(pack("!ic", 1, UNCHOKE))
            self.choke_sent = self.upload.choked
        buf.write(self._outqueue.getvalue())
        # Rewind as well as truncate so later writes start at offset 0 on any
        # StringIO implementation (io.StringIO.truncate does not move the
        # position).
        self._outqueue.seek(0)
        self._outqueue.truncate(0)
        queue = buf.getvalue()
        self.fire_sent_listeners(len(queue))
        self.connection.write(queue)
        return len(queue)

    # yields the number of bytes it wants next, gets those in self._message
    def _read_messages(self):
        """Generator driving the receive side of the protocol.

        Each ``yield n`` asks for the next n bytes.  Presumably the data
        handler (not shown here) then places them in self._message, and any
        already-received surplus in self._rest, before resuming the
        generator -- confirm against the caller.  Returning ends the
        generator, which presumably ends the connection.

        Phases: optional obfuscated (DH + RC4) handshake, then the classic
        header, reserved bytes, infohash and peer id, then the
        length-prefixed message loop.
        """

        # be compatible with encrypted clients. Thanks Uoti
        # First 20 bytes: either the plaintext protocol header or the start of
        # a 96-byte DH public key.
        yield 1 + len(protocol_name)
        # Take the obfuscated path if we initiated one (we hold a private key)
        # or the peer's first bytes are not the plaintext header.
        if self._privkey is not None or \
           self._message != chr(len(protocol_name)) + protocol_name:
            if self.locally_initiated:
                # We dialled in plaintext but got a non-plaintext reply: give up.
                if self._privkey is None:
                    return
                # Read the rest of the peer's DH public key and derive the
                # shared secret S.
                dhstr = self._message
                yield DH_BYTES - len(dhstr)
                dhstr += self._message
                pub = bytetonum(dhstr)
                S = numtobyte(pow(pub, self._privkey, dh_prime))
                pub = self._privkey = dhstr = None
                SKEY = self.parent.infohash
                # Torrent identifier: sha('req2'+infohash) XOR sha('req3'+S).
                x = sha('req3' + S).digest()
                streamid = sha('req2'+SKEY).digest()
                streamid = ''.join([chr(ord(streamid[i]) ^ ord(x[i]))
                                    for i in range(20)])
                # Outgoing RC4 stream; the first 1024 keystream bytes are
                # discarded.
                encrypt = ARC4.new(sha('keyA' + S + SKEY).digest()).encrypt
                encrypt('x'*1024)
                padlen = randrange(PAD_MAX)
                # sha('req1'+S), streamid, then encrypted: 8 zero bytes (VC),
                # 4-byte crypto_provide = 2, 2-byte pad length, the pad, and a
                # zero 2-byte length (presumably "no initial payload").
                x = sha('req1' + S).digest() + streamid + encrypt(
                    '\x00'*8 + '\x00'*3+'\x02'+'\x00'+chr(padlen)+
                    urandom(padlen)+'\x00\x00')
                self.connection.write(x)
                # From here on, presumably everything written is encrypted.
                self.connection.encrypt = encrypt
                decrypt = ARC4.new(sha('keyB' + S + SKEY).digest()).decrypt
                decrypt('x'*1024)
                # What the peer's 8 zero VC bytes look like on the wire.  This
                # also advances the decrypt keystream past the VC.
                VC = decrypt('\x00'*8) # actually encrypt
                # Scan the peer's random padding for the encrypted VC, giving
                # up after 520 bytes (512 max padding + 8 byte VC).
                x = ''
                while 1:
                    yield 1
                    x += self._message
                    i = (x + str(self._rest)).find(VC)
                    if i >= 0:
                        break
                    yield len(self._rest)
                    x += self._message
                    if len(x) >= 520:
                        self.protocol_violation('VC not found')
                        return
                # Read through VC + 4-byte crypto_select + 2-byte pad length.
                # Only the last 6 bytes are decrypted because the keystream
                # already covers the VC.
                yield i + 8 + 4 + 2 - len(x)
                x = decrypt((x + self._message)[-6:])
                self._decrypt = decrypt
                if x[0:4] != '\x00\x00\x00\x02':
                    self.protocol_violation('bad crypto method selected, not 2')
                    return
                padlen = (ord(x[4]) << 8) + ord(x[5])
                if padlen > 512:
                    self.protocol_violation('padlen too long')
                    return
                # Classic handshake header, sent through the encrypted stream.
                self.connection.write(''.join((chr(len(protocol_name)),
                                               protocol_name, FLAGS,
                                               self.parent.infohash)))
                # Skip the peer's padding.
                yield padlen
            else:
                # Incoming obfuscated connection: read the rest of the peer's
                # key, reply with ours plus random padding, derive S.
                dhstr = self._message
                yield DH_BYTES - len(dhstr)
                dhstr += self._message
                privkey = bytetonum(urandom(20))
                pub = numtobyte(pow(2, privkey, dh_prime))
                self.connection.write(''.join((pub, urandom(randrange(PAD_MAX)))))
                pub = bytetonum(dhstr)
                S = numtobyte(pow(pub, privkey, dh_prime))
                dhstr = pub = privkey = None
                # Synchronise on sha('req1'+S) within 532 bytes
                # (512 max padding + 20 byte hash).
                streamid = sha('req1' + S).digest()
                x = ''
                while 1:
                    yield 1
                    x += self._message
                    i = (x + str(self._rest)).find(streamid)
                    if i >= 0:
                        break
                    yield len(self._rest)
                    x += self._message
                    if len(x) >= 532:
                        self.protocol_violation('incoming VC not found')
                        return
                # Read through: req1 hash (20), obfuscated torrent id (20),
                # then encrypted VC (8), crypto_provide (4), pad length (2).
                # Keep the last 34 bytes.
                yield i + 20 + 20 + 8 + 4 + 2 - len(x)
                self._message = (x + self._message)[-34:]
                # Undo the XOR with sha('req3'+S) to recover sha('req2'+infohash).
                streamid = self._message[0:20]
                x = sha('req3' + S).digest()
                streamid = ''.join([chr(ord(streamid[i]) ^ ord(x[i]))
                                    for i in range(20)])
                # Presumably sets parent.infohash when the hash matches a
                # torrent we run.
                self.parent.select_torrent_obfuscated(self, streamid)
                if self.parent.infohash is None:
                    self.protocol_violation('download id unknown/rejected')
                    return
                self.logger = logging.getLogger(
                    self.log_prefix + '.' + repr(self.parent.infohash) +
                    '.peer_id_not_yet')
                SKEY = self.parent.infohash
                decrypt = ARC4.new(sha('keyA' + S + SKEY).digest()).decrypt
                decrypt('x'*1024)
                s = decrypt(self._message[20:34])
                if s[0:8] != '\x00' * 8:
                    self.protocol_violation('BAD VC')
                    return
                crypto_provide = toint(s[8:12])
                padlen = (ord(s[12]) << 8) + ord(s[13])
                if padlen > 512:
                    self.protocol_violation('BAD padlen, too long')
                    return
                self._decrypt = decrypt
                # Skip the peer's padding plus the following 2-byte length
                # field; that field is read but not interpreted, and whatever
                # follows is parsed as the classic handshake below.
                yield padlen + 2
                s = self._message
                encrypt = ARC4.new(sha('keyB' + S + SKEY).digest()).encrypt
                encrypt('x'*1024)
                self.connection.encrypt = encrypt
                if not crypto_provide & 2:
                    self.protocol_violation("peer doesn't support crypto mode 2")
                    return
                # VC (8 zero bytes) + crypto_select = 2 + 2-byte pad length +
                # pad, written through the connection's encrypt hook.
                padlen = randrange(PAD_MAX)
                s = '\x00' * 11 + '\x02\x00' + chr(padlen) + urandom(padlen)
                self.connection.write(s)
            # Drop key material; the classic handshake header follows.
            S = SKEY = s = x = streamid = VC = padlen = None
            yield 1 + len(protocol_name)
            if self._message != chr(len(protocol_name)) + protocol_name:
                self.protocol_violation('classic handshake fails')
                return

        # An extension is enabled only if the peer sets the bit and we
        # advertise it in FLAGS (fast extension: unless disabled).
        yield 8  # reserved
        if noisy:
            l = [ c.encode('hex') for c in list(self._message) ]
            log("reserved: %s" % ' '.join(l))

        if ord(self._message[0]) & AZUREUS:
            if noisy: log("Implements Azureus extensions")
            if ord(FLAGS[0]) & AZUREUS:
                self.uses_azureus_extension = True
        if ord(self._message[5]) & UTORRENT:
            if noisy: log("Implements uTorrent extensions")
            if ord(FLAGS[5]) & UTORRENT:
                self.uses_utorrent_extension = True
        if ord(self._message[7]) & DHT:
            if noisy: log("Implements DHT")
            if ord(FLAGS[7]) & DHT:
                self.uses_dht = True
        if ord(self._message[7]) & FAST_EXTENSION:
            if noisy: log("Implements FAST_EXTENSION")
            if not disable_fast_extension:
                self.uses_fast_extension = True
        if ord(self._message[7]) & NAT_TRAVERSAL:
            if noisy: log("Implements NAT_TRAVERSAL")
            if ord(FLAGS[7]) & NAT_TRAVERSAL:
                self.uses_nat_traversal = True


        yield 20 # download id (i.e., infohash)
        if self.parent.infohash is None:  # incoming connection
            # modifies self.parent if successful
            self.parent.select_torrent(self, self._message)
            if self.parent.infohash is None:
                self.protocol_violation("no infohash from parent (peer from a "
                                        "torrent you're not running)")
                return
        elif self._message != self.parent.infohash:
            self.protocol_violation("incorrect infohash from parent")
            return

        # Incoming: now that the torrent is known, send our full handshake
        # including our peer id.
        if not self.locally_initiated:
            self.connection.write(''.join((chr(len(protocol_name)),
                                           protocol_name, FLAGS,
                                           self.parent.infohash,
                                           self.parent.my_id)))

        yield 20  # peer id
        if noisy: log("peer id: %r" % self._message)
        # if we don't already have the peer's id, send ours
        if not self.id:
            self.id = self._message
            ns = (self.log_prefix + '.' + repr(self.parent.infohash) +
                  '.' + self._message.encode('hex'))
            self.logger = logging.getLogger(ns)

            if self.id == self.parent.my_id:
                self.protocol_violation("talking to self")
                return

            if self.id in self.parent.connector_ids:
                self.protocol_violation("duplicate connection (id collision)")
                return
            if (self.parent.config['one_connection_per_ip'] and
                self.ip in self.parent.connector_ips):
                self.protocol_violation("duplicate connection (ip collision)")
                return

            if self.locally_initiated:
                self.connection.write(self.parent.my_id)
            else:
                self.parent.everinc = True
        else:
            # assert the id we have and the one we got are the same
            if self._message != self.id:
                self.protocol_violation("incorrect id have:%r got:%r" % (self.id, self._message))
                return
        self.complete = True
        self.parent.connection_handshake_completed(self)

        # uTorrent extension handshake: advertise ut_pex and our port.
        if self.uses_utorrent_extension:
            response = {'m':{'ut_pex':1},
                        'v': '\xb5Torrent 1.5',
                        'e': 0,
                        'p': self.parent.reported_port,
                        }
            response = bencode(response)
            self._send_message(UTORRENT_MSG,
                               UTORRENT_MSG_INFO, response)

        # Main loop: 4-byte big-endian length, then that many bytes.
        # Zero-length messages (keep-alives) are skipped.
        message_count = 0
        while True:
            yield 4   # message length
            l = toint(self._message)
            if l > self.max_message_length:
                d = '%s%s' % (self._message, self._rest)
                d = d[:10]
                self.protocol_violation("message length exceeds max "
                                        "(%s > %s): %r, count:%d" %
                                        (l, self.max_message_length, d,
                                         message_count))
                return
            if l > 0:
                yield l
                self._got_message(self._message)
                message_count += 1

    def _got_utorrent_msg(self, msg_type, d):
        """Handle a uTorrent extension-protocol message from the peer.

        msg_type -- one-character sub-id: UTORRENT_MSG_INFO (extension
                    handshake) or UTORRENT_MSG_PEX (peer exchange).  Other
                    sub-ids are ignored.
        d        -- the bdecoded payload.  It comes straight from the peer,
                    so every field is treated as optional and untrusted; a
                    malformed payload is logged as a protocol violation
                    instead of raising.
        """
        if not isinstance(d, dict):
            self.protocol_violation("uTorrent message payload is not a dict")
            return
        if msg_type == UTORRENT_MSG_INFO:
            port = d.get('p')
            if port:
                try:
                    self.listening_port = int(port)
                except (TypeError, ValueError):
                    self.protocol_violation("bad uTorrent listen port %r" %
                                            (port,))
            messages = d.get('m')
            if not isinstance(messages, dict):
                # No (valid) extension map: the peer advertises no extensions.
                messages = {}
            self.uses_utorrent_pex = bool(messages.get('ut_pex', 0))
        elif msg_type == UTORRENT_MSG_PEX:
            # Either list may be omitted by the sender.
            added = d.get('added')
            if added:
                for addr in IPTools.uncompact_sequence(added):
                    self.remote_pex_set.add(addr)
                    self.parent.start_connection(addr)
            dropped = d.get('dropped')
            if dropped:
                self.remote_pex_set.difference_update(
                    IPTools.uncompact_sequence(dropped))

    def _got_azureus_msg(self, msg_type, d):
        port = d.get('tcp_port')
        if port:
            self.listening_port = int(port)
        m = d.get('messages', [])
        for msg in m:
            if msg.get('id') == 'AZ_PEER_EXCHANGE':
                self.uses_azureus_pex = True

    def _got_holepunch_msg(self, d):
        """Handle a decoded NAT-traversal HOLE_PUNCH payload ``d``.

        't' == 'r' (request): the peer asks for a hole punch towards the
        compact address in 'p'; we reply with an 'i' (initiate) message
        carrying the peer's own address.
        't' == 'i' (initiate): only logged for now.
        Any other type is logged as a protocol violation.
        """
        if not isinstance(d, dict):
            self.protocol_violation("hole punch payload is not a dict")
            return
        msg_type = d.get('t')
        if msg_type == 'r': # request
            self.logger.info('hole punch requested from %s to %r',
                             self.addr, d.get('p'))
            # Previously referenced an undefined name ``addr``, a NameError.
            # NOTE(review): this replies to the requester itself; forwarding
            # to the target in 'p' may have been intended -- confirm.
            reply = {'t': 'i', 'p': IPTools.compact(*self.addr)}
            # Bencode like the other extension senders; raw dicts cannot be
            # written to the connection.
            self._send_message(HOLE_PUNCH, bencode(reply))
        elif msg_type == 'i': # initiate
            self.logger.info('told to initiate connection(s) to: %r',
                             d.get('p'))
        else:
            self.protocol_violation("unknown hole punch msg type: %r" %
                                    msg_type)
        
    def _got_message(self, message):
        t = message[0]
        if t in [BITFIELD, HAVE_ALL, HAVE_NONE] and self.got_anything:
            self.protocol_violation("%s after got anything" % message_dict[t])
            self.close()
            return
        if t == UTORRENT_MSG and self.uses_utorrent_extension:
            msg_type = message[1]
            d = bdecode(message[2:])
            if noisy: log("UTORRENT_MSG: %r:%r" % (msg_type, d))
            self._got_utorrent_msg(msg_type, d)
            return
        if t == AZUREUS_SUCKS and self.uses_azureus_extension:
            magic_intro = 17
            msg_type = message[:magic_intro]
            d = bdecode(message[magic_intro:])
            if noisy: log("AZUREUS_MSG: %r:%r" % (msg_type, d))
            self._got_azureus_msg(msg_type, d)
            return
        if t == HOLE_PUNCH and self.uses_nat_traversal:
            d = ebdecode(message)
            if noisy: log("HOLE_PUNCH: %r" % d)
            self._got_holepunch_msg(d)
            return
            
        self.got_anything = True
        if (t in (CHOKE, UNCHOKE, INTERESTED, NOT_INTERESTED,
                  HAVE_ALL, HAVE_NONE) and
                len(message) != 1):
            self.protocol_violation("%s with message length %d" %
                                    (message_dict[t], len(message)))
            if noisy: log("UNKNOWN: %r" % message)
            self.close()
            return
        if t == CHOKE:
            if noisy: log("GOT %s" % message_dict[t])
            self.download.got_choke()
        elif t == UNCHOKE:
            if noisy: log("GOT %s" % message_dict[t])
            self.download.got_unchoke()
        elif t == INTERESTED:
            if noisy: log("GOT %s" % message_dict[t])
            self.upload.got_interested()
        elif t == NOT_INTERESTED:
            if noisy: log("GOT %s" % message_dict[t])
            self.upload.got_not_interested()
        elif t == HAVE:
            if len(message) != 5:
                self.protocol_violation("HAVE length: %d != 5" %
                                        len(message))
                self.close()
                return
            i = unpack("!xi", message)[0]
            if noisy: log("GOT HAVE %d" % i)
            if i >= self.parent.numpieces:
                self.protocol_violation("HAVE %d >= %d" %
                                        (i, self.parent.numpieces))
                self.close()
                return
            self.download.got_have(i)
        elif t == BITFIELD:
            try:
                b = Bitfield(self.parent.numpieces, message[1:])
            except ValueError, e:
                self.protocol_violation("BITFIELD %s" %
                                        (e,))
                self.close()
                return
            self.download.got_have_bitfield(b)
        elif t == REQUEST:
            if len(message) != 13:
                self.protocol_violation("REQUEST length %d != 13" %
                                        len(message))
                self.close()
                return
            i, a, b = unpack("!xiii", message)
            if noisy: log("GOT REQUEST %d %d %d" % (i, a, b))
            if i >= self.parent.numpieces:
                self.protocol_violation(
                     "Requested piece index out of range: %d > %d" %
                     (i, self.parent.numpieces))
                self.close()
                return
            if a + b > self.parent.piece_size:
                self.protocol_violation(
                     "Requested range exceeds piece size: "
                     "(b:%d + l:%d == %d) > %d" %
                     (a, b, a + b, self.parent.piece_size))
                self.close()
                return                
            if self.download.have[i]:
                self.protocol_violation(
                     "Requested piece index %d which the peer already has" %
                     (i,))
                self.close()
                return
            self.upload.got_request(i, a, b)
        elif t == CANCEL:
            if len(message) != 13:
                self.protocol_violation("CANCEL length %d != 13" %
                                        len(message))
                self.close()
                return
            i, a, b = unpack("!xiii", message)
            if noisy: log("GOT CANCEL %d %d %d" % (i, a, b))
            if i >= self.parent.numpieces:
                self.protocol_violation(
                     "Cancelled piece index %d > numpieces which is %d" %
                     (i,self.parent.numpieces))
                self.close()
                return
            self.upload.got_cancel(i, a, b)
        elif t == PIECE:
            if len(message) <= 9:
                self.protocol_violation("PIECE %d <= 9" %
                                        len(message))
                self.close()
                return
            n = len(message) - 9
            i, a, b = unpack("!xii%ss" % n, message)
            if noisy: log("GOT PIECE %d %d" % (i, a))
            if i >= self.parent.numpieces:
                self.protocol_violation("PIECE %d >= %d" %
                                        (i, self.parent.numpieces))
                self.close()
                return
            self.download.got_piece(i, a, b)
        elif t == PORT:
            if len(message) != 3:
                self.protocol_violation("PORT %d != 3" %
                                        len(message))
                self.close()
                return
            self.dht_port = unpack('!H', message[1:3])[0]
            self.parent.got_port(self)
        elif t == SUGGEST_PIECE:
            if not self.uses_fast_extension:
                self.protocol_violation(
                    "Received 'SUGGEST_PIECE' when fast extension disabled.")
                self.close()
                return
            if len(message) != 5:
                self.protocol_violation("SUGGEST_PIECE length: %d != 5" %
                                        len(message))
                self.close()
                return
            i = unpack("!xi", message)[0]
            if noisy: log("GOT SUGGEST_PIECE %d" % i)
            if i >= self.parent.numpieces:
                self.protocol_violation(
                    "Received 'SUGGEST_PIECE' with piece id %d > numpieces." %
                    self.parent.numpieces)
                self.close()
                return
            self.download.got_suggest_piece(i)
        elif t == HAVE_ALL:
            if noisy: log("GOT %s" % message_dict[t])
            if not self.uses_fast_extension:
                self.protocol_violation(
                    "Received 'HAVE_ALL' when fast extension disabled.")
                self.close()
                return
            self.download.got_have_all()
        elif t == HAVE_NONE:
            if noisy: log("GOT %s" % message_dict[t])
            if not self.uses_fast_extension:
                self.protocol_violation(
                    "Received 'HAVE_NONE' when fast extension disabled.")
                self.close()
                return
            self.download.got_have_none()
        elif t == REJECT_REQUEST:
            if not self.uses_fast_extension:
                self.protocol_violation(
                    "Received 'REJECT_REQUEST' when fast extension disabled.")
                self.close()
                return
            if len(message) != 13:
                self.protocol_violation(
                    "Received 'REJECT_REQUEST' with length %d != 13." %
                    len(message))
                self.close()
                return
            i, a, b = unpack("!xiii", message)
            if noisy: log("GOT REJECT_REQUEST %d %d" % (i,a))
            if i >= self.parent.numpieces:
                self.protocol_violation("REJECT %d >= %d" %
                                        (i, self.parent.numpieces))
                self.close()
                return
            self.download.got_reject_request(i, a, b)
        elif t == ALLOWED_FAST:
            if not self.uses_fast_extension:
                self.protocol_violation(
                    "Received 'ALLOWED_FAST' when fast extension disabled.")
                self.close()
                return
            if len(message) != 5:
                self.protocol_violation("ALLOWED_FAST length: %d != 5" %
                                        len(message))
                self.close()
                return
            i = unpack("!xi", message)[0]
            if noisy: log("GOT ALLOWED_FAST %d" % i)
            self.download.got_allowed_fast(i)
        else:
            if noisy: log("GOT %s length %d" % (message_dict[t], len(message)))
            self.protocol_violation("unhandled message %s" % message_dict[t])
            self.close()

    def _send_message(self, *msg_a):
        """Send one length-prefixed protocol message.

        The parts in *msg_a* are concatenated and preceded by their total
        length, packed by ``tobinary`` (``struct`` format ``"!i"``).  Nothing
        is sent once the connection is closed.  While a partial message is
        outstanding (``self._partial_message`` is set) the framed bytes are
        appended to ``self._outqueue`` instead of being written to the
        connection -- presumably so they are not interleaved with the
        partially sent data.
        """
        if not self.closed:
            body_len = sum(map(len, msg_a))
            frame = ''.join([tobinary(body_len)] + list(msg_a))
            if self._partial_message is None:
                self.connection.write(frame)
            else:
                self._outqueue.write(frame)

    def data_came_in(self, conn, s):
        """Feed raw bytes received on the connection into the message reader.

        Incoming data is accumulated in ``self._buffer`` until the
        ``self._next_len`` bytes requested by the reader are available.  Each
        complete chunk is decrypted (when ``self._decrypt`` is set), stored in
        ``self._message`` with the unconsumed remainder in ``self._rest``, and
        the ``self._reader`` generator is advanced; the value it yields is the
        size of the next chunk it wants.  The connection is closed when the
        reader is exhausted or raises.

        conn -- connection the data arrived on (only used by the log_data
                debug check).
        s    -- the received bytes.
        """
        self.received_data = True
        if not self.download:
            # this is really annoying.
            self.sloppy_pre_connection_counter += len(s)
        else:
            # NOTE(review): the accumulated pre-connection byte count is
            # reset here without being reported anywhere -- confirm that
            # discarding it is intended.
            self.sloppy_pre_connection_counter = 0

        if log_data:
            assert self.addr == (conn.ip, conn.port)
            # Close the dump file explicitly instead of relying on garbage
            # collection to release the handle.
            log_file = open('%s_%d.log' % self.addr, 'ab')
            try:
                log_file.write(s)
            finally:
                log_file.close()

        while True:
            if self.closed:
                return
            i = self._next_len - self._buffer.tell()
            if i > len(s):
                # not enough bytes, keep buffering
                self._buffer.write(s)
                return
            if self._buffer.tell() > 0:
                # collect buffer + current for message
                self._buffer.write(buffer(s, 0, i))
                m = self._buffer.getvalue()
                # Replace the buffer rather than truncating it so its memory
                # is freed (truncate(0) would instead save reallocations).
                self._buffer.close()
                self._buffer = StringIO()
            else:
                # painful string copy
                m = s[:i]
            # Keep the unconsumed remainder as a buffer view (no copy).
            s = buffer(s, i)
            if self._decrypt is not None:
                m = self._decrypt(m)
            self._message = m
            self._rest = s
            try:
                self._next_len = self._reader.next()
            except StopIteration:
                # The reader has finished; close the connection.
                self.close()
                return
            except Exception:
                # Catch ordinary errors only: a bare ``except:`` would also
                # swallow KeyboardInterrupt / SystemExit on Python >= 2.5.
                self.logger.exception("Message parsing failed")
                self.close()
                return

    def _optional_restart(self):
        """Retry the peer with encryption after a silent failed connection.

        Only applies when this side initiated the connection, no data was
        ever received from the peer, and outgoing obfuscation was not already
        in use; the parent is then asked to open a new connection to the same
        address with ``encrypt=True``.  (Presumably this handles peers that
        refuse plaintext handshakes -- confirm.)
        """
        if (not self.locally_initiated or self.received_data
                or self.obfuscate_outgoing):
            return
        self.parent.start_connection(self.addr, id=None, encrypt=True)

    def connection_lost(self, conn):
        """Tear down state after the transport reports the connection lost.

        Marks this object closed, drops the reader generator, notifies the
        parent, lets ``_optional_restart`` decide whether to retry the peer
        with encryption, and releases the connection object.  When
        ``self.complete`` is set (presumably: the handshake had finished),
        the download handler is told via ``disconnected()`` and both the
        upload and download handlers are released.

        conn -- the lost transport; must be ``self.connection``.
        """
        assert conn is self.connection
        self.closed = True
        self._reader = None
        self.parent.connection_lost(self)

        # Must run after the parent has been notified and before the
        # connection reference is cleared, matching the original order.
        self._optional_restart()

        self.connection = None
        if not self.complete:
            return
        download = self.download
        if download is not None:
            download.disconnected()
        self.upload = None
        self.download = None

    def connection_flushed(self, connection):
        if (self.complete and self.next_upload is None and 
            (self._partial_message is not None
             or (self.upload and self.upload.buffer))):
            if self.lan:
                # bypass upload rate limiter
                self.send_partial(self.parent.ratelimiter.unitsize)
            else:
                self.parent.ratelimiter.queue(self)