#############################################################################
#
#	protocol.py - Pyro Protocol Adapters
#
#	This is part of "Pyro" - Python Remote Objects
#	which is (c) Irmen de Jong - irmen@bigfoot.com.
#
#############################################################################

import select, socket, struct
import Pyro
from Pyro.util import pickle, Log
from Pyro.errors import *

#------ Get the hostname (and cache it)
_hostname=None		# cached result of socket.gethostname()
def getHostname():
	# Return the local host name, asking the OS only the first time;
	# later calls reuse the cached value.
	global _hostname
	if not _hostname:
		_hostname=socket.gethostname()
	return _hostname

#------ Get our IP address (and cache it) (return None on error)
_ip=None		# cached IP address of this host, or None
def getIPAddress():
	# Return this host's IP address as resolved from getHostname().
	# A successful lookup is cached; if resolution raises socket.error
	# the result is None and the lookup is retried on the next call.
	global _ip
	if not _ip:
		try:
			_ip=socket.gethostbyname(getHostname())
		except socket.error:
			_ip=None
	return _ip


#------ Socket helper functions for sending and receiving data correctly.

# Receive a precise number of bytes from a socket. Raises
# SocketClosedError if that number of bytes could not be read
# (the connection has probably been closed).
# Never will this function return an empty message (if size>0).
# We need this because 'recv' isn't guaranteed to return all desired
# bytes in one call, for instance, when network load is high.
def sock_recvmsg(sock, size):
	"""Read exactly `size` bytes from `sock` and return them as one string.

	recv() may return fewer bytes than requested, so keep reading until
	the full amount has arrived. Returns '' when size <= 0.
	Raises SocketClosedError if recv() returns no data before `size`
	bytes were read (the peer has most likely closed the connection).
	"""
	chunks=[]
	remaining=size
	while remaining>0:
		chunk=sock.recv(remaining)
		if not chunk:
			raise SocketClosedError('connection lost')
		chunks.append(chunk)
		remaining=remaining-len(chunk)
	# Join once at the end instead of concatenating per chunk, which
	# would copy the accumulated data on every iteration (quadratic).
	return ''.join(chunks)

# Send a message over a socket. Raises SocketClosedError if the msg
# couldn't be sent (the connection has probably been lost then).
# We need this because 'send' isn't guaranteed to send all desired
# bytes in one call, for instance, when network load is high.
def sock_sendmsg(sock, msg):
	# Push all of `msg` through `sock`. A single send() may accept only
	# part of the data, so resend the unsent tail until everything is out.
	# A send() that reports 0 bytes means the connection is gone:
	# raise SocketClosedError in that case.
	offset=0
	while offset<len(msg):
		count=sock.send(msg[offset:])
		if count==0:
			raise SocketClosedError('connection lost')
		offset=offset+count



#------ PYRO: adapter (default Pyro wire protocol)
#------ This adapter is for protocol version 1.0 ONLY;
#------ it accepts 1.0-style headers exclusively.
#------ Future adapters should be backward compatible and more flexible.

class PYROAdapter:
	# Adapter for the default Pyro wire protocol, version 1.0.
	# Every message is a fixed-size header (headerFmt) followed by a pickled
	# body; the header carries the protocol ID, version, revision, header
	# size and body size (see createMsg).
	def __init__(self):
		self.daemon = None			# the Pyro daemon which uses us
		self.headerFmt = '!4sBBHL'	# version 1.0 header
		self.headerID = 'PYRO'
		self.acceptMSG= 'GRANTED'	# must be same length as denyMSG
		self.denyMSG=   'DENIED!'	# must be same length as acceptMSG
		self.headerSize = struct.calcsize(self.headerFmt)
		self.version, self.revision = (1,0)	# version 1.0
	def setDaemon(self, daemon):
		self.daemon = daemon
		Log.msg('PYROAdapter','adapter daemon set to',str(daemon))
	def sendAccept(self, sock):		# called by TCPServer
		sock_sendmsg(sock, self.acceptMSG)
	def sendDeny(self, sock):		# called by TCPServer
		sock_sendmsg(sock, self.denyMSG)
	def bindToURI(self,URI):
		"""Connect to the server named by URI and perform the handshake.

		On success self.sock is the connected socket and self.host,
		self.port and self.objectID are copied from URI.
		Raises ProtocolError if URI is not a PYRO URI, if the connection
		fails with socket.error, or if the handshake reply is neither the
		accept nor the deny message. Raises ServerFullError if the server
		denies the connection. The socket is closed whenever binding fails.
		"""
		if URI.protocol!='PYRO':
			Log.error('PYROAdapter','incompatible protocol in URI:',URI.protocol)
			raise ProtocolError('incompatible protocol in URI')
		self.host, self.port, self.objectID = (None,None,None)
		sock = None
		bound = False
		try:
			try:
				sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
				self.sock = sock
				# connect() takes one (host, port) address tuple.
				sock.connect((URI.host, URI.port))
				msg = sock_recvmsg(sock, len(self.acceptMSG))
			except socket.error:
				Log.msg('PYROAdapter','connection failed to URI',str(URI))
				raise ProtocolError('connection failed')
			if msg==self.acceptMSG:
				self.host = URI.host
				self.port = URI.port
				self.objectID = URI.objectID
				bound = True
				Log.msg('PYROAdapter','connected to',str(URI))
			elif msg==self.denyMSG:
				raise ServerFullError('connection denied, too busy')
			else:
				# Previously an unknown reply left the adapter silently unbound.
				Log.error('PYROAdapter','invalid handshake reply from',str(URI))
				raise ProtocolError('invalid handshake reply')
		finally:
			# Don't leak the socket when the handshake did not succeed.
			if not bound and sock is not None:
				sock.close()

	def createMsg(self, body):
		"""Return `body` prefixed with a version 1.0 protocol header."""
		return struct.pack(self.headerFmt, self.headerID,
					self.version, self.revision,
					self.headerSize, len(body)) + body

	def remoteInvocation(self, method, flags, *args):
		"""Call `method` on the bound remote object and return its result.

		Sends a pickled (objectID, method, flags, args) request and blocks
		until the reply arrives. If the reply is a PyroExceptionCapsule,
		the encapsulated exception is raised here.
		"""
		body=pickle.dumps((self.objectID,method,flags,args),Pyro.config.PYRO_BINARY_PICKLE)
		sock_sendmsg(self.sock, self.createMsg(body))
		answer = pickle.loads(self.receiveMsg(self.sock))
		if isinstance(answer,PyroExceptionCapsule):
			# we have an encapsulated exception, raise it again.
			answer.raiseEx()
		return answer

	def receiveMsg(self,sock):
		"""Read one message from `sock` and return its (still pickled) body.

		Raises ProtocolError if the header ID or header size is wrong, or
		if the version/revision differs from this adapter's.
		"""
		msg=sock_recvmsg(sock, self.headerSize)
		(hdrid, ver, rev, hsiz, bsiz) = struct.unpack(self.headerFmt,msg)
		if hdrid!=self.headerID or hsiz!=self.headerSize:
			Log.error('PYROAdapter','invalid header')
			raise ProtocolError('invalid header')
		if ver!=self.version or rev!=self.revision:
			Log.error('PYROAdapter','incompatible version')
			raise ProtocolError('incompatible version')
		body=sock_recvmsg(sock, bsiz)
		return body

	def handleInvocation(self,sock):
		"""Serve one request read from `sock` and send back exactly one reply.

		The reply is the pickled result of the method call, or an
		encapsulated ProtocolError when the object ID is unknown.
		"""
		body = self.receiveMsg(sock)
		# Unpickle the request, which is a tuple:
		#  (object ID, method name, flags, (arg1,arg2,...))
		# NOTE(review): pickle.loads on data received from the network can
		# execute arbitrary code; only expose this server to trusted clients.
		req=pickle.loads(body)
		try:
			# find the object in the implementation database of our daemon
			o = self.daemon.implementations[req[0]]
		except KeyError:
			Log.warn('PYROAdapter','Invocation to unknown object rejected:',req[0])
			# The client blocks waiting for a reply (remoteInvocation), so
			# answer with an error instead of staying silent.
			self.returnException(sock, ProtocolError('unknown object ID'))
		else:
			# call the method on this object
			res = o[0].Pyro_dyncall(req[1],req[2],req[3])	# (method,flags,args)
			# reply the result to the caller
			body=pickle.dumps(res,Pyro.config.PYRO_BINARY_PICKLE)
			sock_sendmsg(sock, self.createMsg(body))

	def returnException(self, sock, exc):
		"""Send `exc` to the client wrapped in a PyroExceptionCapsule."""
		body=pickle.dumps(PyroExceptionCapsule(exc),Pyro.config.PYRO_BINARY_PICKLE)
		sock_sendmsg(sock, self.createMsg(body))
	
		
def getProtocolAdapter(protocol):
	# Return a new adapter for `protocol`. Only 'PYRO' is supported;
	# any other name is logged and rejected with ProtocolError.
	if protocol!='PYRO':
		Log.error('getProtocolAdapter','unsupported protocol:',protocol)
		raise ProtocolError('unsupported protocol')
	return PYROAdapter()


#-------- TCPConnection object for TCPServer class
class TCPConnection:
	# Pairs an accepted client socket with its peer address. fileno()
	# delegates to the socket so instances can be passed to select().
	def __init__(self, sock, addr):
		self.addr = addr
		self.sock = sock
	def fileno(self):
		sock = self.sock
		return sock.fileno()

#-------- TCPServer base class
class TCPServer:
	def __init__(self, requestServer, port):
		self.slave = requestServer
		self.slave.daemon=self
		self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
		self.sock.bind('',port)
		self.sock.listen(5)
		self.connections = []
		self.setParamsForLoop(5)
	def __del__(self):
		if len(self.connections)>0:
			Log.warn('TCPServer','Shutting down but there are still',len(self.connections),'active connections')
			
	def setParamsForLoop(self, timeout, others=[], callback=None):
		self.timeout=timeout
		self.others=others
		self.callback=callback
	def connectionLoop(self, condition=lambda:1):
		while condition():
			self.handleRequests(self.timeout,self.others,self.callback)
	def newConnection(self,sock):
		csock, addr = sock.accept()
		conn=TCPConnection(csock,addr)
		if len(self.connections)<Pyro.config.PYRO_MAXCONNECTIONS:
			self.adapter.sendAccept(csock)
			self.connections.append(conn)
			Log.msg('TCPServer','new connection from',addr,'#conns=',len(self.connections))
		else:
			# we have too many open connections. Disconnect this one.
			Log.msg('TCPServer','Too many open connections, closing',addr,'#conns=',len(self.connections))
			self.adapter.sendDeny(csock)

	def handleRequests(self, timeout=None, others=[], callback=None):
		activecnt=1	
		while activecnt:
			socklist = self.connections+[self.sock]+others
			if timeout==None:
				ins,outs,exs = select.select(socklist,[],[])
			else:
				ins,outs,exs = select.select(socklist,[],[],timeout)
			activecnt=len(ins)
			if self.sock in ins:
				self.newConnection(self.sock)
				ins.remove(self.sock)
			for c in ins[0:]:
				if isinstance(c,TCPConnection):
					try:
						self.slave.handleRequest(c)
					except:
						self.handleError(c)
					ins.remove(c)
			if ins and callback:
				# the 'others' must have fired...
				callback(ins)

	def handleError(self,conn):
		# Default implementation. Most likely you have to override
		# this in your subclass.
		print '-'*40
		print 'Exception happened during processing of request from',conn.addr
		import traceback
		traceback.print_exc()
		print '-'*40

	# to be called by slave class if it detects a dropped connection:
	def removeConnection(self, conn):
		if conn in self.connections:
			self.connections.remove(conn)
			Log.msg('TCPServer','removed connection with',conn.addr,' #conns=',len(self.connections))

