"SfR Fresh" - the SfR Freeware/Shareware Archive 
Member "BitTorrent-4.26.0/BitTorrent/Connector.py" of archive BitTorrent-4.26.0.tar.gz:
# The contents of this file are subject to the BitTorrent Open Source License
# Version 1.1 (the License). You may not copy or use this file, in either
# source code or executable form, except in compliance with the License. You
# may obtain a copy of the License at http://www.bittorrent.com/license/.
#
# Software distributed under the License is distributed on an AS IS basis,
# WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License
# for the specific language governing rights and limitations under the
# License.
# Originally written by Bram Cohen, heavily modified by Uoti Urpala
# Fast extensions added by David Harrison
from __future__ import generators
# DEBUG
# If you think FAST_EXTENSION is causing problems then set the following;
# it stops us advertising the fast extension in our handshake and makes us
# ignore the peer's fast-extension bit.
disable_fast_extension = False
# END DEBUG
# When true, protocol events are written to the module-level ``log``
# function (a DEBUG logger with a stderr StreamHandler, set up below).
noisy = False
# When true, every raw chunk received from a peer is appended to a
# "<ip>_<port>.log" file in the current directory; this also forces noisy.
log_data = False
# for crypto
from random import randrange
from BTL.hash import sha
from Crypto.Cipher import ARC4
# urandom comes from obsoletepythonsupport
import struct
from struct import pack, unpack
from cStringIO import StringIO
from BTL.bencode import bencode, bdecode
from BitTorrent.RawServer_twisted import Handler
from BTL.bitfield import Bitfield
from BTL import IPTools
from BTL.obsoletepythonsupport import *
from BitTorrent.ClientIdentifier import identify_client
import logging
def toint(s):
    """Decode the 4-byte string ``s`` as a big-endian signed integer."""
    (value,) = struct.unpack("!i", s)
    return value
def tobinary(i):
    """Encode ``i`` as a 4-byte big-endian signed integer string."""
    encoded = struct.pack("!i", i)
    return encoded
class BTMessages(object):
    """Two-way lookup between one-character wire message ids and names.

    ``messages`` maps integer message codes to symbolic names.  Indexing an
    instance with a one-character id returns its name, or an
    ``"UNKNOWN: ..."`` description for ids that are not in the table.
    """

    def __init__(self, messages):
        self.message_to_chr = {}
        self.chr_to_message = {}
        for code, name in messages.items():
            wire_id = chr(code)
            self.message_to_chr[name] = wire_id
            self.chr_to_message[wire_id] = name

    def __getitem__(self, key):
        try:
            return self.chr_to_message[key]
        except KeyError:
            return "UNKNOWN: %r" % key
# Wire message ids (the byte that follows the 4-byte length prefix) mapped to
# symbolic names.  The comment above an entry describes its payload.
message_dict = BTMessages({
    # no payload
    0:'CHOKE',
    1:'UNCHOKE',
    2:'INTERESTED',
    3:'NOT_INTERESTED',
    # index
    4:'HAVE',
    # bitfield (one bit per piece)
    5:'BITFIELD',
    # index, begin, length
    6:'REQUEST',
    # index, begin, piece
    7:'PIECE',
    # index, begin, length
    8:'CANCEL',
    # 2-byte port message
    9:'PORT',
    # no args
    10:'WANT_METAINFO',
    11:'METAINFO',
    # index
    12:'SUSPECT_PIECE',
    # index
    13:'SUGGEST_PIECE', # FAST_EXTENSION
    # no args
    14:'HAVE_ALL', # FAST_EXTENSION
    15:'HAVE_NONE', # FAST_EXTENSION
    # index, begin, length
    16:'REJECT_REQUEST', # FAST_EXTENSION
    # index
    17:'ALLOWED_FAST', # FAST_EXTENSION
    # bencoded dict: 't' sub-type, 'p' compact address
    18:'HOLE_PUNCH', # NAT_TRAVERSAL
    # message id, bencoded payload
    20:'UTORRENT_MSG', # UTORRENT
    })
# put all the message identifiers in the module (e.g. CHOKE == '\x00');
# they are one-character strings compared directly against message[0].
locals().update(message_dict.message_to_chr)
# The Azureus messaging protocol reuses message id 0, the same byte as
# CHOKE, so Azureus messages are recognised by that byte in _got_message.
AZUREUS_SUCKS = CHOKE
# Extension sub-ids carried in the byte after UTORRENT_MSG.
UTORRENT_MSG_INFO = chr(0)
UTORRENT_MSG_PEX = chr(1)
# reserved flags (bits in the 8 reserved bytes of the handshake):
# reserved[0]
# 0x80 Azureus Messaging Protocol
AZUREUS = 0x80
# reserved[5]
# 0x10 uTorrent extensions: peer exchange, encrypted connections,
# broadcast listen port.
UTORRENT = 0x10
# reserved[7]
DHT = 0x01
FAST_EXTENSION = 0x04 # suggest, haveall, havenone, reject request,
                      # and allow fast extensions.
NAT_TRAVERSAL = 0x08 # holepunch
LAST_BYTE = DHT
if not disable_fast_extension:
    LAST_BYTE |= FAST_EXTENSION
LAST_BYTE |= NAT_TRAVERSAL
# The reserved bytes we advertise in our own handshake.  The Azureus bit is
# not advertised (its assignment is commented out).
FLAGS = ['\0'] * 8
#FLAGS[0] = chr( AZUREUS )
FLAGS[5] = chr( UTORRENT )
FLAGS[7] = chr( LAST_BYTE )
FLAGS = ''.join(FLAGS)
protocol_name = 'BitTorrent protocol'
# for crypto: Diffie-Hellman modulus; public keys are pow(2, private, dh_prime)
dh_prime = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A63A36210000000000090563
PAD_MAX = 200 # less than protocol maximum, and later assumed to be < 256
# Size in bytes of a Diffie-Hellman public key on the wire (see numtobyte).
DH_BYTES = 96
def bytetonum(x):
return long(x.encode('hex'), 16)
def numtobyte(x):
x = hex(x).lstrip('0x').rstrip('Ll')
x = '0'*(192 - len(x)) + x
return x.decode('hex')
# Dumping raw traffic (log_data) implies verbose protocol logging (noisy).
noisy = noisy or log_data
if noisy:
    # Debug output for this module goes to stderr through one shared
    # handler; Connector instances attach the same handler to their loggers.
    stream_handler = logging.StreamHandler()
    connection_logger = logging.getLogger("BitTorrent.Connector")
    connection_logger.setLevel(logging.DEBUG)
    connection_logger.addHandler(stream_handler)
    log = connection_logger.debug
# Tracker NAT checking:
# Aside: When you start up a Torrent, the first connection after contacting
# the tracker is probably a callback from the tracker to perform a NatCheck.
# (I was a bit confused about where this connection was coming from: it
# didn't have any bits set in the handshake's reserved bytes, even when
# there were no other peers.) --Dave
class Connector(Handler):
    """Implements the syntax of the BitTorrent protocol.

    A Connector parses and produces the byte stream for one peer
    connection: the optional obfuscated (Diffie-Hellman + RC4) handshake,
    the classic handshake, and the length-prefixed messages that follow.
    See Upload.py and Download.py for the connection-level
    semantics."""
def __init__(self, parent, connection, id, is_local,
             obfuscate_outgoing=False, log_prefix = "", lan=False):
    """Set up protocol state for one peer connection.

    parent -- owner object; supplies config, infohash, my_id and the
              connection bookkeeping callbacks used by this class.
    connection -- transport exposing ``ip``, ``port``, ``write`` and
                  ``close``; this object installs itself as its handler.
    id -- the peer's id when already known, otherwise None.
    is_local -- True when we initiated the connection; our handshake is
                then sent immediately.
    obfuscate_outgoing -- start an outgoing connection with the
                          Diffie-Hellman handshake instead of the plain one.
    log_prefix -- prefix for this connection's logger name.
    lan -- when true, connection_flushed bypasses the upload rate limiter.
    """
    # Who we are talking to.
    self.parent = parent
    self.connection = connection
    self.id = id
    self.ip = connection.ip
    self.port = connection.port
    self.addr = (self.ip, self.port)
    self.hostname = None
    self.locally_initiated = is_local
    self.lan = lan
    self.log_prefix = log_prefix
    self.obfuscate_outgoing = obfuscate_outgoing

    # For an outgoing connection the port we dialled is the peer's
    # listening port and the parent's config is already usable; incoming
    # connections learn these later.
    self.listening_port = None
    if is_local:
        self.max_message_length = self.parent.config['max_message_length']
        self.listening_port = self.port

    # Connection lifecycle flags.
    self.complete = False
    self.closed = False
    self.got_anything = False
    self.received_data = False
    self.choke_sent = True
    self.sloppy_pre_connection_counter = 0

    # Upload/download objects; None until assigned elsewhere.
    self.next_upload = None
    self.upload = None
    self.download = None

    # Extensions negotiated from the peer's handshake.
    self.uses_utorrent_extension = False
    self.uses_utorrent_pex = False
    self.uses_azureus_extension = False
    self.uses_azureus_pex = False
    self.uses_dht = False
    self.uses_fast_extension = False
    self.uses_nat_traversal = False
    self.dht_port = None
    self.local_pex_set = set()
    self.remote_pex_set = set()
    self._sent_listeners = set()

    # Buffers and stream-obfuscation state.
    self._buffer = StringIO()
    self._outqueue = StringIO()
    self._partial_message = None
    self._decrypt = None
    self._privkey = None
    # Prime the parser generator; its first yield is how many bytes it wants.
    self._reader = self._read_messages()
    self._next_len = self._reader.next()

    if self.locally_initiated:
        logger_name = (self.log_prefix + '.' + repr(self.parent.infohash) +
                       '.peer_id_not_yet')
    else:
        logger_name = self.log_prefix + '.infohash_not_yet.peer_id_not_yet'
    self.logger = logging.getLogger(logger_name)
    self.logger.setLevel(logging.DEBUG)
    if noisy:
        self.logger.addHandler(stream_handler)

    if self.locally_initiated:
        self.send_handshake()
    # Become the transport's handler only once all state above exists.
    self.connection.handler = self
def protocol_violation(self, s):
    """Log that the peer broke the protocol.

    The message is ``s`` followed by the peer address and, when the peer id
    is known, the identified client.  It is logged at INFO on this
    connection's logger (and to the debug log when ``noisy``).  This method
    does not close the connection.
    """
    parts = ["%s %s" % (s, self.addr)]
    if self.id:
        parts.append("%r" % (identify_client(self.id), ))
    msg = " ".join(parts)
    if noisy:
        log("FAUX PAS: %s" % msg)
    self.logger.info(msg)
def send_handshake(self):
    """Send the opening bytes of an outgoing connection.

    Obfuscated: a fresh Diffie-Hellman public key followed by 0 to
    PAD_MAX-1 random padding bytes.  The private key is kept in
    ``self._privkey`` for _read_messages.
    Plain: the handshake header (protocol name, our reserved FLAGS, the
    infohash), followed by our peer id if the remote id is already known.
    """
    if not self.obfuscate_outgoing:
        if noisy:
            hex_flags = [flag.encode('hex') for flag in FLAGS]
            log("sending reserved: %s" % ' '.join(hex_flags))
        header = ''.join((chr(len(protocol_name)), protocol_name,
                          FLAGS, self.parent.infohash))
        self.connection.write(header)
        # With a known peer id we can send ours straight away; otherwise
        # it is sent after the peer's id has been read.
        if self.id is not None:
            self.connection.write(self.parent.my_id)
        return
    self._privkey = privkey = bytetonum(urandom(20))
    public_key = pow(2, privkey, dh_prime)
    self.connection.write(numtobyte(public_key) +
                          urandom(randrange(PAD_MAX)))
def set_parent(self, parent):
    """Attach this connection to ``parent`` and adopt its message size limit."""
    self.parent = parent
    self.max_message_length = parent.config['max_message_length']
def close(self):
    """Ask the transport to close the connection.

    The peer's address is first removed from the parent's address cache.
    ``self.closed`` is not set here (connection_lost sets it), so repeated
    calls before then close the transport again.
    """
    if noisy: log("CLOSE")
    if self.closed:
        return
    self.parent.remove_addr_from_cache(self.addr)
    self.connection.close()
def send_interested(self):
    """Send INTERESTED: we want pieces the peer has."""
    msg_id = INTERESTED
    if noisy:
        log("SEND %s" % (message_dict[msg_id],))
    self._send_message(msg_id)

def send_not_interested(self):
    """Send NOT_INTERESTED: we no longer want any of the peer's pieces."""
    msg_id = NOT_INTERESTED
    if noisy:
        log("SEND %s" % (message_dict[msg_id],))
    self._send_message(msg_id)

def send_choke(self):
    """Send CHOKE now, unless a PIECE message is only partly written.

    While ``_partial_message`` is pending nothing is sent here; send_partial
    later compares ``choke_sent`` with ``upload.choked`` and emits the
    pending CHOKE/UNCHOKE itself.
    """
    if self._partial_message is not None:
        return
    if noisy:
        log("SEND %s" % (message_dict[CHOKE],))
    self._send_message(CHOKE)
    self.choke_sent = True
    self.upload.sent_choke()

def send_unchoke(self):
    """Send UNCHOKE now, unless a PIECE message is only partly written
    (send_partial then reconciles ``choke_sent`` with ``upload.choked``)."""
    if self._partial_message is not None:
        return
    if noisy:
        log("SEND %s" % (message_dict[UNCHOKE],))
    self._send_message(UNCHOKE)
    self.choke_sent = False

def send_port(self, port):
    """Send a PORT message carrying ``port`` as a 2-byte big-endian value."""
    if noisy:
        log("SEND %s" % (message_dict[PORT],))
    self._send_message(PORT, pack('!H', port))
def send_request(self, index, begin, length):
    """Request ``length`` bytes at offset ``begin`` of piece ``index``."""
    if noisy:
        log("SEND %s %d %d %d" % (message_dict[REQUEST], index, begin, length))
    request = pack("!ciii", REQUEST, index, begin, length)
    self._send_message(request)

def send_cancel(self, index, begin, length):
    """Withdraw an earlier REQUEST for the same block."""
    cancel = pack("!ciii", CANCEL, index, begin, length)
    self._send_message(cancel)

def send_bitfield(self, bitfield):
    """Send BITFIELD; ``bitfield`` is the already-encoded string payload."""
    if noisy:
        log("SEND %s" % (message_dict[BITFIELD],))
    self._send_message(BITFIELD, bitfield)

def send_have(self, index):
    """Announce that we now have piece ``index``."""
    if noisy:
        log("SEND %s" % (message_dict[HAVE],))
    have = pack("!ci", HAVE, index)
    self._send_message(have)

def send_have_all(self):
    """FAST_EXTENSION: announce that we have every piece."""
    assert self.uses_fast_extension
    if noisy:
        log("SEND %s" % (message_dict[HAVE_ALL],))
    # pack("!c", HAVE_ALL) is the one-character id itself.
    self._send_message(HAVE_ALL)

def send_have_none(self):
    """FAST_EXTENSION: announce that we have no pieces."""
    assert self.uses_fast_extension
    if noisy:
        log("SEND %s" % (message_dict[HAVE_NONE],))
    self._send_message(HAVE_NONE)

def send_reject_request(self, index, begin, length):
    """FAST_EXTENSION: refuse a REQUEST for the given block."""
    assert self.uses_fast_extension
    reject = pack("!ciii", REJECT_REQUEST, index, begin, length)
    self._send_message(reject)

def send_allowed_fast(self, index):
    """FAST_EXTENSION: allow the peer to request piece ``index`` while choked."""
    assert self.uses_fast_extension
    allowed = pack("!ci", ALLOWED_FAST, index)
    self._send_message(allowed)

def send_keepalive(self):
    """Send a keep-alive: a zero length prefix with no message body."""
    self._send_message('')
def send_holepunch_request(self, addr):
    """Ask the peer to help open a NAT hole towards ``addr`` (an (ip, port)).

    Currently disabled: the method returns before sending anything.  The
    code after the early return is kept for when NAT traversal is
    re-enabled.
    """
    # disabled, for now.
    return
    if not self.uses_nat_traversal:
        # maybe close?
        return
    d = {'t': 'r'}
    d['p'] = IPTools.compact(*addr)
    # _send_message joins its string arguments, so the dict payload has to
    # be bencoded; passing the dict itself would raise TypeError.
    self._send_message(HOLE_PUNCH, bencode(d))
def send_pex(self, pex_set):
    """Send the peer a ut_pex delta describing ``pex_set``.

    ``pex_set`` is the full set of addresses we currently want to advertise.
    Only the differences from the set advertised last time are sent, as
    compact 'added' / 'dropped' lists.  Nothing is sent if the peer did not
    negotiate uTorrent PEX, or if nothing changed.
    """
    if not (self.uses_utorrent_extension and self.uses_utorrent_pex):
        return
    added = pex_set.difference(self.local_pex_set)
    dropped = self.local_pex_set.difference(pex_set)
    # Keep a private snapshot as the baseline.  Storing the caller's set
    # itself would let later in-place changes by the caller rewrite our
    # baseline and suppress future deltas.
    self.local_pex_set = set(pex_set)
    if not (added or dropped):
        return
    d = {}
    d['added'] = IPTools.compact_sequence(added)
    d['added.f'] = chr(0) * len(added) # one (zero) flag byte per added peer
    d['dropped'] = IPTools.compact_sequence(dropped)
    self._send_message(UTORRENT_MSG,
                       UTORRENT_MSG_PEX, bencode(d))
def add_sent_listener(self, listener):
    """Register ``listener``, a callable taking one integer.

    Each time send_partial writes upload data, the listener is called with
    the number of bytes about to be written.
    """
    self._sent_listeners.add(listener)

def remove_sent_listener(self, listener):
    """Unregister ``listener``; raises KeyError if it was never added."""
    self._sent_listeners.remove(listener)

def fire_sent_listeners(self, bytes):
    """Call every registered listener with the byte count ``bytes``."""
    for callback in self._sent_listeners:
        callback(bytes)
def send_partial(self, bytes):
    """Write up to ``bytes`` bytes of queued upload data to the peer.

    If no PIECE data is currently half-sent, blocks are taken from
    ``self.upload.buffer`` (``((index, begin, length), piece)`` entries) and
    framed as PIECE messages until at least ``bytes`` bytes are queued.  If
    that exceeds the budget, only the first ``bytes`` bytes are written and
    the rest is kept in ``_partial_message``.  Once the partial data is fully
    flushed, any pending CHOKE/UNCHOKE and the messages queued in
    ``_outqueue`` follow it.

    Returns the number of bytes written, or 0 if there was nothing to send.
    """
    if self.closed:
        return 0
    if self._partial_message is None and not self.upload.buffer:
        return 0
    if self._partial_message is None:
        buf = StringIO()
        while self.upload.buffer and buf.tell() < bytes:
            t, piece = self.upload.buffer.pop(0)
            index, begin, length = t
            # PIECE header: length prefix (id + index + begin + block =
            # 9 + len(piece)), the id byte, index and begin.  The block is
            # written separately below, so the format has no trailing
            # field.  It previously ended in a bare repeat count
            # ("!icii%s" % len(piece)) with no format character.
            msg = pack("!icii", len(piece) + 9, PIECE, index, begin)
            buf.write(msg)
            buf.write(piece)
            if noisy: log("SEND PIECE %d %d" % (index,begin))
        self._partial_message = buf.getvalue()
    if bytes < len(self._partial_message):
        self.fire_sent_listeners(bytes)
        self.connection.write(buffer(self._partial_message, 0, bytes))
        self._partial_message = buffer(self._partial_message, bytes)
        return bytes
    buf = StringIO()
    buf.write(self._partial_message)
    self._partial_message = None
    # A choke state change requested while the piece was in flight is sent
    # now that the piece boundary has been reached.
    if self.choke_sent != self.upload.choked:
        if self.upload.choked:
            self._outqueue.write(pack("!ic", 1, CHOKE))
            self.upload.sent_choke()
        else:
            self._outqueue.write(pack("!ic", 1, UNCHOKE))
        self.choke_sent = self.upload.choked
    buf.write(self._outqueue.getvalue())
    self._outqueue.truncate(0)
    queue = buf.getvalue()
    self.fire_sent_listeners(len(queue))
    self.connection.write(queue)
    return len(queue)
# yields the number of bytes it wants next, gets those in self._message
# (data_came_in also leaves any further unconsumed input in self._rest).
# When the generator returns, data_came_in closes the connection.
def _read_messages(self):
    """Generator implementing the whole read side of a peer connection.

    Phase 1 (optional): stream-obfuscation handshake.  Diffie-Hellman key
    agreement is followed by RC4, in the style of the Message Stream
    Encryption handshake.  It runs when we opened an obfuscated outgoing
    connection (``self._privkey`` is set), or when the peer's first bytes
    are not the plain protocol header.
    Phase 2: the classic handshake -- protocol header, 8 reserved bytes,
    20-byte infohash, 20-byte peer id.
    Phase 3: an endless loop of 4-byte length-prefixed messages, each body
    handed to ``_got_message``.
    """
    # be compatible with encrypted clients. Thanks Uoti
    # These first 20 bytes are either the plain header ('\x13' + protocol
    # name) or the start of the peer's Diffie-Hellman public key.
    yield 1 + len(protocol_name)
    if self._privkey is not None or \
       self._message != chr(len(protocol_name)) + protocol_name:
        if self.locally_initiated:
            # We sent a plain handshake but the reply is not plain: give up.
            if self._privkey is None:
                return
            # Read the rest of the peer's DH public key and derive the
            # shared secret S.
            dhstr = self._message
            yield DH_BYTES - len(dhstr)
            dhstr += self._message
            pub = bytetonum(dhstr)
            S = numtobyte(pow(pub, self._privkey, dh_prime))
            pub = self._privkey = dhstr = None
            SKEY = self.parent.infohash
            # Obfuscated torrent id: sha('req2'+SKEY) XOR sha('req3'+S).
            x = sha('req3' + S).digest()
            streamid = sha('req2'+SKEY).digest()
            streamid = ''.join([chr(ord(streamid[i]) ^ ord(x[i]))
                                for i in range(20)])
            encrypt = ARC4.new(sha('keyA' + S + SKEY).digest()).encrypt
            # Discard the first 1024 bytes of RC4 keystream.
            encrypt('x'*1024)
            padlen = randrange(PAD_MAX)
            # Send sha('req1'+S) and the obfuscated id, then encrypted:
            # 8-byte zero verification constant (VC), crypto_provide = 2,
            # a 2-byte pad length, the random pad, and a zero initial-payload
            # length.
            x = sha('req1' + S).digest() + streamid + encrypt(
                '\x00'*8 + '\x00'*3+'\x02'+'\x00'+chr(padlen)+
                urandom(padlen)+'\x00\x00')
            self.connection.write(x)
            # Later writes go through ``encrypt`` (presumably applied by the
            # transport's write method -- not visible here).
            self.connection.encrypt = encrypt
            decrypt = ARC4.new(sha('keyB' + S + SKEY).digest()).decrypt
            decrypt('x'*1024)
            VC = decrypt('\x00'*8) # actually encrypt
            # Scan the incoming bytes for the encrypted VC, giving up after
            # 520 bytes.
            x = ''
            while 1:
                yield 1
                x += self._message
                i = (x + str(self._rest)).find(VC)
                if i >= 0:
                    break
                yield len(self._rest)
                x += self._message
                if len(x) >= 520:
                    self.protocol_violation('VC not found')
                    return
            # Read through VC (8), crypto_select (4) and the pad length (2).
            yield i + 8 + 4 + 2 - len(x)
            # ``decrypt`` already consumed 8 keystream bytes producing VC,
            # so it lines up with the 6 bytes that follow VC.
            x = decrypt((x + self._message)[-6:])
            self._decrypt = decrypt
            if x[0:4] != '\x00\x00\x00\x02':
                self.protocol_violation('bad crypto method selected, not 2')
                return
            padlen = (ord(x[4]) << 8) + ord(x[5])
            if padlen > 512:
                self.protocol_violation('padlen too long')
                return
            # Send the classic handshake header over the encrypted stream,
            # then skip the peer's padding.
            self.connection.write(''.join((chr(len(protocol_name)),
                                           protocol_name, FLAGS,
                                           self.parent.infohash)))
            yield padlen
        else:
            # Incoming obfuscated connection: finish reading the peer's DH
            # public key, answer with ours plus random padding, derive S.
            dhstr = self._message
            yield DH_BYTES - len(dhstr)
            dhstr += self._message
            privkey = bytetonum(urandom(20))
            pub = numtobyte(pow(2, privkey, dh_prime))
            self.connection.write(''.join((pub, urandom(randrange(PAD_MAX)))))
            pub = bytetonum(dhstr)
            S = numtobyte(pow(pub, privkey, dh_prime))
            dhstr = pub = privkey = None
            # Scan past the peer's padding for sha('req1'+S), giving up
            # after 532 bytes.
            streamid = sha('req1' + S).digest()
            x = ''
            while 1:
                yield 1
                x += self._message
                i = (x + str(self._rest)).find(streamid)
                if i >= 0:
                    break
                yield len(self._rest)
                x += self._message
                if len(x) >= 532:
                    self.protocol_violation('incoming VC not found')
                    return
            # Read through the req1 hash (20), obfuscated id (20), VC (8),
            # crypto_provide (4) and the pad length (2).
            yield i + 20 + 20 + 8 + 4 + 2 - len(x)
            # Keep the last 34 bytes: obfuscated id + 14 encrypted bytes.
            self._message = (x + self._message)[-34:]
            streamid = self._message[0:20]
            # Undo the req3 XOR, leaving sha('req2' + infohash); the parent
            # presumably maps that back to a torrent and sets its infohash.
            x = sha('req3' + S).digest()
            streamid = ''.join([chr(ord(streamid[i]) ^ ord(x[i]))
                                for i in range(20)])
            self.parent.select_torrent_obfuscated(self, streamid)
            if self.parent.infohash is None:
                self.protocol_violation('download id unknown/rejected')
                return
            self.logger = logging.getLogger(
                self.log_prefix + '.' + repr(self.parent.infohash) +
                '.peer_id_not_yet')
            SKEY = self.parent.infohash
            decrypt = ARC4.new(sha('keyA' + S + SKEY).digest()).decrypt
            decrypt('x'*1024)
            # Decrypt VC (must be all zero), crypto_provide and pad length.
            s = decrypt(self._message[20:34])
            if s[0:8] != '\x00' * 8:
                self.protocol_violation('BAD VC')
                return
            crypto_provide = toint(s[8:12])
            padlen = (ord(s[12]) << 8) + ord(s[13])
            if padlen > 512:
                self.protocol_violation('BAD padlen, too long')
                return
            self._decrypt = decrypt
            # Skip the padding plus the 2-byte initial-payload length.
            # NOTE(review): that length is not checked; any initial payload
            # is then parsed as the classic handshake below.
            yield padlen + 2
            s = self._message
            encrypt = ARC4.new(sha('keyB' + S + SKEY).digest()).encrypt
            encrypt('x'*1024)
            self.connection.encrypt = encrypt
            if not crypto_provide & 2:
                self.protocol_violation("peer doesn't support crypto mode 2")
                return
            # Reply: VC (8 zero bytes), crypto_select = 2, pad length, pad.
            padlen = randrange(PAD_MAX)
            s = '\x00' * 11 + '\x02\x00' + chr(padlen) + urandom(padlen)
            self.connection.write(s)
        # Drop key material; the classic header must follow, now decrypted.
        S = SKEY = s = x = streamid = VC = padlen = None
        yield 1 + len(protocol_name)
        if self._message != chr(len(protocol_name)) + protocol_name:
            self.protocol_violation('classic handshake fails')
            return
    # Reserved bytes: each extension is enabled only if the peer sets its
    # bit and we advertise it in FLAGS; the fast extension is gated by
    # disable_fast_extension instead.
    yield 8 # reserved
    if noisy:
        l = [ c.encode('hex') for c in list(self._message) ]
        log("reserved: %s" % ' '.join(l))
    if ord(self._message[0]) & AZUREUS:
        if noisy: log("Implements Azureus extensions")
        if ord(FLAGS[0]) & AZUREUS:
            self.uses_azureus_extension = True
    if ord(self._message[5]) & UTORRENT:
        if noisy: log("Implements uTorrent extensions")
        if ord(FLAGS[5]) & UTORRENT:
            self.uses_utorrent_extension = True
    if ord(self._message[7]) & DHT:
        if noisy: log("Implements DHT")
        if ord(FLAGS[7]) & DHT:
            self.uses_dht = True
    if ord(self._message[7]) & FAST_EXTENSION:
        if noisy: log("Implements FAST_EXTENSION")
        if not disable_fast_extension:
            self.uses_fast_extension = True
    if ord(self._message[7]) & NAT_TRAVERSAL:
        if noisy: log("Implements NAT_TRAVERSAL")
        if ord(FLAGS[7]) & NAT_TRAVERSAL:
            self.uses_nat_traversal = True
    yield 20 # download id (i.e., infohash)
    if self.parent.infohash is None: # incoming connection
        # modifies self.parent if successful
        self.parent.select_torrent(self, self._message)
        if self.parent.infohash is None:
            self.protocol_violation("no infohash from parent (peer from a "
                                    "torrent you're not running)")
            return
    elif self._message != self.parent.infohash:
        self.protocol_violation("incorrect infohash from parent")
        return
    # The accepting side answers with its complete handshake, id included.
    if not self.locally_initiated:
        self.connection.write(''.join((chr(len(protocol_name)),
                                       protocol_name, FLAGS,
                                       self.parent.infohash,
                                       self.parent.my_id)))
    yield 20 # peer id
    if noisy: log("peer id: %r" % self._message)
    # if we don't already have the peer's id, send ours
    if not self.id:
        self.id = self._message
        ns = (self.log_prefix + '.' + repr(self.parent.infohash) +
              '.' + self._message.encode('hex'))
        self.logger = logging.getLogger(ns)
        if self.id == self.parent.my_id:
            self.protocol_violation("talking to self")
            return
        if self.id in self.parent.connector_ids:
            self.protocol_violation("duplicate connection (id collision)")
            return
        if (self.parent.config['one_connection_per_ip'] and
            self.ip in self.parent.connector_ips):
            self.protocol_violation("duplicate connection (ip collision)")
            return
        if self.locally_initiated:
            self.connection.write(self.parent.my_id)
        else:
            self.parent.everinc = True
    else:
        # assert the id we have and the one we got are the same
        if self._message != self.id:
            self.protocol_violation("incorrect id have:%r got:%r" % (self.id, self._message))
            return
    self.complete = True
    self.parent.connection_handshake_completed(self)
    # uTorrent extension handshake: advertise ut_pex and our port.
    if self.uses_utorrent_extension:
        response = {'m':{'ut_pex':1},
                    'v': '\xb5Torrent 1.5',
                    'e': 0,
                    'p': self.parent.reported_port,
                    }
        response = bencode(response)
        self._send_message(UTORRENT_MSG,
                           UTORRENT_MSG_INFO, response)
    # Message loop: 4-byte big-endian length, then that many body bytes.
    message_count = 0
    while True:
        yield 4 # message length
        l = toint(self._message)
        if l > self.max_message_length:
            d = '%s%s' % (self._message, self._rest)
            d = d[:10]
            self.protocol_violation("message length exceeds max "
                                    "(%s > %s): %r, count:%d" %
                                    (l, self.max_message_length, d,
                                     message_count))
            return
        # l == 0 is a keep-alive.  NOTE(review): toint is signed, so a
        # negative length is silently skipped the same way.
        if l > 0:
            yield l
            self._got_message(self._message)
            message_count += 1
def _got_utorrent_msg(self, msg_type, d):
if msg_type == UTORRENT_MSG_INFO:
version = d.get('v')
port = d.get('p')
if port:
self.listening_port = int(port)
encryption = d.get('e')
messages = d.get('m')
self.uses_utorrent_pex = bool(messages.get('ut_pex', 0))
elif msg_type == UTORRENT_MSG_PEX:
for addr in IPTools.uncompact_sequence(d['added']):
self.remote_pex_set.add(addr)
self.parent.start_connection(addr)
dropped_gen = IPTools.uncompact_sequence(d['dropped'])
self.remote_pex_set.difference_update(dropped_gen)
def _got_azureus_msg(self, msg_type, d):
port = d.get('tcp_port')
if port:
self.listening_port = int(port)
m = d.get('messages', [])
for msg in m:
if msg.get('id') == 'AZ_PEER_EXCHANGE':
self.uses_azureus_pex = True
def _got_holepunch_msg(self, d):
msg_type = d.get('t')
if msg_type == 'r': # request
print 'hole punch requested from', self.addr, 'to', d['p']
d = {'t': 'i'}
d['p'] = IPTools.compact(addr)
self._send_message(HOLE_PUNCH, d)
elif msg_type == 'i': # initiate
print 'told to initiate connection(s) to:' + str(d['p'])
else:
self.protocol_violation("unknown hole punch msg type: %r" %
msg_type)
def _got_message(self, message):
t = message[0]
if t in [BITFIELD, HAVE_ALL, HAVE_NONE] and self.got_anything:
self.protocol_violation("%s after got anything" % message_dict[t])
self.close()
return
if t == UTORRENT_MSG and self.uses_utorrent_extension:
msg_type = message[1]
d = bdecode(message[2:])
if noisy: log("UTORRENT_MSG: %r:%r" % (msg_type, d))
self._got_utorrent_msg(msg_type, d)
return
if t == AZUREUS_SUCKS and self.uses_azureus_extension:
magic_intro = 17
msg_type = message[:magic_intro]
d = bdecode(message[magic_intro:])
if noisy: log("AZUREUS_MSG: %r:%r" % (msg_type, d))
self._got_azureus_msg(msg_type, d)
return
if t == HOLE_PUNCH and self.uses_nat_traversal:
d = ebdecode(message)
if noisy: log("HOLE_PUNCH: %r" % d)
self._got_holepunch_msg(d)
return
self.got_anything = True
if (t in (CHOKE, UNCHOKE, INTERESTED, NOT_INTERESTED,
HAVE_ALL, HAVE_NONE) and
len(message) != 1):
self.protocol_violation("%s with message length %d" %
(message_dict[t], len(message)))
if noisy: log("UNKNOWN: %r" % message)
self.close()
return
if t == CHOKE:
if noisy: log("GOT %s" % message_dict[t])
self.download.got_choke()
elif t == UNCHOKE:
if noisy: log("GOT %s" % message_dict[t])
self.download.got_unchoke()
elif t == INTERESTED:
if noisy: log("GOT %s" % message_dict[t])
self.upload.got_interested()
elif t == NOT_INTERESTED:
if noisy: log("GOT %s" % message_dict[t])
self.upload.got_not_interested()
elif t == HAVE:
if len(message) != 5:
self.protocol_violation("HAVE length: %d != 5" %
len(message))
self.close()
return
i = unpack("!xi", message)[0]
if noisy: log("GOT HAVE %d" % i)
if i >= self.parent.numpieces:
self.protocol_violation("HAVE %d >= %d" %
(i, self.parent.numpieces))
self.close()
return
self.download.got_have(i)
elif t == BITFIELD:
try:
b = Bitfield(self.parent.numpieces, message[1:])
except ValueError, e:
self.protocol_violation("BITFIELD %s" %
(e,))
self.close()
return
self.download.got_have_bitfield(b)
elif t == REQUEST:
if len(message) != 13:
self.protocol_violation("REQUEST length %d != 13" %
len(message))
self.close()
return
i, a, b = unpack("!xiii", message)
if noisy: log("GOT REQUEST %d %d %d" % (i, a, b))
if i >= self.parent.numpieces:
self.protocol_violation(
"Requested piece index out of range: %d > %d" %
(i, self.parent.numpieces))
self.close()
return
if a + b > self.parent.piece_size:
self.protocol_violation(
"Requested range exceeds piece size: "
"(b:%d + l:%d == %d) > %d" %
(a, b, a + b, self.parent.piece_size))
self.close()
return
if self.download.have[i]:
self.protocol_violation(
"Requested piece index %d which the peer already has" %
(i,))
self.close()
return
self.upload.got_request(i, a, b)
elif t == CANCEL:
if len(message) != 13:
self.protocol_violation("CANCEL length %d != 13" %
len(message))
self.close()
return
i, a, b = unpack("!xiii", message)
if noisy: log("GOT CANCEL %d %d %d" % (i, a, b))
if i >= self.parent.numpieces:
self.protocol_violation(
"Cancelled piece index %d > numpieces which is %d" %
(i,self.parent.numpieces))
self.close()
return
self.upload.got_cancel(i, a, b)
elif t == PIECE:
if len(message) <= 9:
self.protocol_violation("PIECE %d <= 9" %
len(message))
self.close()
return
n = len(message) - 9
i, a, b = unpack("!xii%ss" % n, message)
if noisy: log("GOT PIECE %d %d" % (i, a))
if i >= self.parent.numpieces:
self.protocol_violation("PIECE %d >= %d" %
(i, self.parent.numpieces))
self.close()
return
self.download.got_piece(i, a, b)
elif t == PORT:
if len(message) != 3:
self.protocol_violation("PORT %d != 3" %
len(message))
self.close()
return
self.dht_port = unpack('!H', message[1:3])[0]
self.parent.got_port(self)
elif t == SUGGEST_PIECE:
if not self.uses_fast_extension:
self.protocol_violation(
"Received 'SUGGEST_PIECE' when fast extension disabled.")
self.close()
return
if len(message) != 5:
self.protocol_violation("SUGGEST_PIECE length: %d != 5" %
len(message))
self.close()
return
i = unpack("!xi", message)[0]
if noisy: log("GOT SUGGEST_PIECE %d" % i)
if i >= self.parent.numpieces:
self.protocol_violation(
"Received 'SUGGEST_PIECE' with piece id %d > numpieces." %
self.parent.numpieces)
self.close()
return
self.download.got_suggest_piece(i)
elif t == HAVE_ALL:
if noisy: log("GOT %s" % message_dict[t])
if not self.uses_fast_extension:
self.protocol_violation(
"Received 'HAVE_ALL' when fast extension disabled.")
self.close()
return
self.download.got_have_all()
elif t == HAVE_NONE:
if noisy: log("GOT %s" % message_dict[t])
if not self.uses_fast_extension:
self.protocol_violation(
"Received 'HAVE_NONE' when fast extension disabled.")
self.close()
return
self.download.got_have_none()
elif t == REJECT_REQUEST:
if not self.uses_fast_extension:
self.protocol_violation(
"Received 'REJECT_REQUEST' when fast extension disabled.")
self.close()
return
if len(message) != 13:
self.protocol_violation(
"Received 'REJECT_REQUEST' with length %d != 13." %
len(message))
self.close()
return
i, a, b = unpack("!xiii", message)
if noisy: log("GOT REJECT_REQUEST %d %d" % (i,a))
if i >= self.parent.numpieces:
self.protocol_violation("REJECT %d >= %d" %
(i, self.parent.numpieces))
self.close()
return
self.download.got_reject_request(i, a, b)
elif t == ALLOWED_FAST:
if not self.uses_fast_extension:
self.protocol_violation(
"Received 'ALLOWED_FAST' when fast extension disabled.")
self.close()
return
if len(message) != 5:
self.protocol_violation("ALLOWED_FAST length: %d != 5" %
len(message))
self.close()
return
i = unpack("!xi", message)[0]
if noisy: log("GOT ALLOWED_FAST %d" % i)
self.download.got_allowed_fast(i)
else:
if noisy: log("GOT %s length %d" % (message_dict[t], len(message)))
self.protocol_violation("unhandled message %s" % message_dict[t])
self.close()
def _send_message(self, *msg_a):
    """Frame the concatenation of ``msg_a`` with a 4-byte length and send it.

    Does nothing once the connection is closed.  While a PIECE message is
    only partly written (``_partial_message`` is set), the framed message is
    appended to ``_outqueue`` instead; send_partial writes the queue after
    the piece, so the message is not spliced into the middle of piece data.
    """
    if self.closed:
        return
    payload = ''.join(msg_a)
    framed = tobinary(len(payload)) + payload
    if self._partial_message is None:
        self.connection.write(framed)
    else:
        self._outqueue.write(framed)
def data_came_in(self, conn, s):
    """Feed raw bytes received from the transport into the message reader.

    Input is buffered until the ``_read_messages`` generator's last
    requested length (``_next_len``) is available.  Each complete chunk is
    decrypted once an obfuscated stream is set up, stored in
    ``self._message`` (with the remaining input in ``self._rest``), and the
    generator is advanced.  The connection is closed when the generator
    finishes or raises.
    """
    self.received_data = True
    if not self.download:
        # this is really annoying.
        self.sloppy_pre_connection_counter += len(s)
    else:
        self.sloppy_pre_connection_counter = 0
    if log_data:
        assert self.addr == (conn.ip, conn.port)
        # Close the debug dump file right away instead of leaking the handle.
        logfile = open('%s_%d.log' % self.addr, 'ab')
        try:
            logfile.write(s)
        finally:
            logfile.close()
    while True:
        if self.closed:
            return
        i = self._next_len - self._buffer.tell()
        if i > len(s):
            # not enough bytes, keep buffering
            self._buffer.write(s)
            return
        if self._buffer.tell() > 0:
            # collect buffer + current for message
            self._buffer.write(buffer(s, 0, i))
            m = self._buffer.getvalue()
            # A fresh StringIO releases the buffer memory (truncate(0) would
            # reuse it and save mallocs instead).
            self._buffer.close()
            self._buffer = StringIO()
        else:
            # painful string copy
            m = s[:i]
        s = buffer(s, i)
        if self._decrypt is not None:
            m = self._decrypt(m)
        self._message = m
        self._rest = s
        try:
            self._next_len = self._reader.next()
        except StopIteration:
            self.close()
            return
        except Exception:
            # Log parser errors and drop the peer; a bare except would also
            # swallow KeyboardInterrupt and SystemExit.
            self.logger.exception("Message parsing failed")
            self.close()
            return
def _optional_restart(self):
if self.locally_initiated and not self.received_data and not self.obfuscate_outgoing:
self.parent.start_connection(self.addr, id=None, encrypt=True)
def connection_lost(self, conn):
    """Tear down state after the transport ``conn`` has closed.

    Marks the connection closed and drops the reader, then notifies the
    parent and possibly retries with obfuscation (_optional_restart).  If
    the handshake had completed, the download object is told about the
    disconnect before the upload/download references are released.
    """
    assert conn is self.connection
    self.closed = True
    self._reader = None
    self.parent.connection_lost(self)
    self._optional_restart()
    self.connection = None
    if self.complete and self.download is not None:
        self.download.disconnected()
    self.upload = self.download = None
def connection_flushed(self, connection):
    """Transport write buffer drained: schedule more upload data if pending.

    Acts only after the handshake is complete and when no upload is already
    scheduled (``next_upload`` is None), and only if piece data is still
    pending (a partial message or a non-empty upload buffer).  LAN peers are
    sent one rate-limiter unit directly; others are queued on the parent's
    rate limiter.
    """
    if not self.complete or self.next_upload is not None:
        return
    has_pending = (self._partial_message is not None
                   or (self.upload and self.upload.buffer))
    if not has_pending:
        return
    limiter = self.parent.ratelimiter
    if self.lan:
        # bypass upload rate limiter
        self.send_partial(limiter.unitsize)
    else:
        limiter.queue(self)