blob: 15946a660ffef65d75d3617abc4342f3ab849cd7 [file] [log] [blame]
"""RPC Implementation, originally written for the Python Idle IDE
2
3For security reasons, GvR requested that Idle's Python execution server process
connect to the Idle process, which listens for the connection. Since Idle
has only one client per server, this was not a limitation.
6
7 +---------------------------------+ +-------------+
8 | SocketServer.BaseRequestHandler | | SocketIO |
9 +---------------------------------+ +-------------+
10 ^ | register() |
11 | | unregister()|
12 | +-------------+
13 | ^ ^
14 | | |
15 | + -------------------+ |
16 | | |
17 +-------------------------+ +-----------------+
18 | RPCHandler | | RPCClient |
19 | [attribute of RPCServer]| | |
20 +-------------------------+ +-----------------+
21
22The RPCServer handler class is expected to provide register/unregister methods.
23RPCHandler inherits the mix-in class SocketIO, which provides these methods.
24
25See the Idle run.main() docstring for further information on how this was
26accomplished in Idle.
27
28"""
29
30import sys
31import socket
32import select
33import SocketServer
34import struct
35import cPickle as pickle
36import threading
37import traceback
38import copy_reg
39import types
40import marshal
41
def unpickle_code(ms):
    """Rebuild a code object from its marshal serialization *ms*.

    Counterpart of pickle_code(); this is the constructor that unpickling
    calls for a pickled code object.

    Raises TypeError if *ms* does not decode to a code object.  An explicit
    check is used instead of assert, because assert is stripped under -O.
    """
    # NOTE(review): marshal.loads is unsafe on untrusted data.  Here the bytes
    # arrive from the RPC peer, which is already trusted to send pickles.
    co = marshal.loads(ms)
    if not isinstance(co, types.CodeType):
        raise TypeError("expected a marshalled code object, got %r"
                        % (type(co),))
    return co
46
def pickle_code(co):
    """Reduce code object *co* for pickling.

    Returns the ``(callable, args)`` pair used by copy_reg: unpickling calls
    ``unpickle_code(ms)``, where ms is ``marshal.dumps(co)``.

    Raises TypeError if *co* is not a code object.  An explicit check is used
    instead of assert, because assert is stripped under -O.
    """
    if not isinstance(co, types.CodeType):
        raise TypeError("expected a code object, got %r" % (type(co),))
    return unpickle_code, (marshal.dumps(co),)
51
# Let pickle serialize code objects by routing them through marshal, using
# the pickle_code / unpickle_code reducer-constructor pair.
# XXX KBK 24Aug02: an analogous function-pickling hook was once sketched
# here, but Idle never needed it, so nothing is registered for FunctionType.
copy_reg.pickle(types.CodeType, pickle_code, unpickle_code)

# Number of bytes requested from the socket per recv() call.
BUFSIZE = 8 * 1024
64
65class RPCServer(SocketServer.TCPServer):
66
67 def __init__(self, addr, handlerclass=None):
68 if handlerclass is None:
69 handlerclass = RPCHandler
70 SocketServer.TCPServer.__init__(self, addr, handlerclass)
71
72 def server_bind(self):
73 "Override TCPServer method, no bind() phase for connecting entity"
74 pass
75
76 def server_activate(self):
77 """Override TCPServer method, connect() instead of listen()
78
79 Due to the reversed connection, self.server_address is actually the
80 address of the Idle Client to which we are connecting.
81
82 """
83 self.socket.connect(self.server_address)
84
85 def get_request(self):
86 "Override TCPServer method, return already connected socket"
87 return self.socket, self.server_address
88
89 def handle_error(self, request, client_address):
90 """Override TCPServer method
91
92 Error message goes to __stderr__. No error message if exiting
93 normally or socket raised EOF. Other exceptions not handled in
94 server code will cause os._exit.
95
96 """
97 try:
98 raise
99 except SystemExit:
100 raise
101 except EOFError:
102 pass
103 except:
104 erf = sys.__stderr__
105 print>>erf, '\n' + '-'*40
106 print>>erf, 'Unhandled server exception!'
107 print>>erf, 'Thread: %s' % threading.currentThread().getName()
108 print>>erf, 'Client Address: ', client_address
109 print>>erf, 'Request: ', repr(request)
110 traceback.print_exc(file=erf)
111 print>>erf, '\n*** Unrecoverable, server exiting!'
112 print>>erf, '-'*40
113 import os
114 os._exit(0)
115
116
# Default registry mapping object ids to the objects served over RPC.
# SocketIO instances created without an explicit objtable share it, and
# remoteref() adds entries to it.
objecttable = dict()
118
class SocketIO:
    """RPC endpoint shared by both ends of an Idle RPC connection.

    Messages are pickled ``(seq, payload)`` tuples.  Each is framed on the
    socket by a 4-byte little-endian length prefix (putmessage / pollpacket).
    A payload is one of two things:

    * a request ``("call", (oid, methodname, args, kwargs))``, served by
      localcall() against the objects in self.objtable;
    * a response ``(how, what)`` with how in "OK", "EXCEPTION", "ERROR",
      interpreted by decoderesponse().

    Only self.mainthread (the thread that created the instance) reads the
    socket.  Any other thread that makes a call waits on self.cvar until the
    main thread stores the matching response in self.responses.
    """

    nextseq = 0
    # Defaults so a plain SocketIO can call debug(); subclasses override them.
    debugging = False
    location = "#?"

    def __init__(self, sock, objtable=None, debugging=None):
        """Wrap the connected socket *sock*.

        objtable maps object ids to served objects and defaults to the
        module-level objecttable.  debugging, if given, overrides the class
        setting for this instance.
        """
        self.mainthread = threading.currentThread()
        if debugging is not None:
            self.debugging = debugging
        self.sock = sock
        if objtable is None:
            objtable = objecttable
        self.objtable = objtable
        self.cvar = threading.Condition()
        self.responses = {}   # seq -> response awaiting an auxiliary thread
        self.cvars = {}       # seq -> condition the waiting thread sleeps on
        self.interrupted = False

    def close(self):
        """Close the socket; later sends raise IOError (see putmessage)."""
        sock = self.sock
        self.sock = None
        if sock is not None:
            sock.close()

    def debug(self, *args):
        """Write a tagged debug line to sys.__stderr__ if debugging is on."""
        if not self.debugging:
            return
        s = self.location + " " + str(threading.currentThread().getName())
        for a in args:
            s = s + " " + str(a)
        print>>sys.__stderr__, s

    def register(self, oid, object):
        """Serve *object* under id *oid*."""
        self.objtable[oid] = object

    def unregister(self, oid):
        """Stop serving *oid*; unknown ids are ignored."""
        try:
            del self.objtable[oid]
        except KeyError:
            pass

    def localcall(self, request):
        """Execute a ("call", ...) request against a registered object.

        Returns one of:
        * ("OK", result);
        * ("ERROR", message) for a malformed request, an unknown object id
          or an unknown method name;
        * ("EXCEPTION", None) when the method itself raised.

        SystemExit and socket.error propagate.
        """
        self.debug("localcall:", request)
        try:
            how, (oid, methodname, args, kwargs) = request
        except (TypeError, ValueError):
            # TypeError: not unpackable; ValueError: wrong number of items.
            return ("ERROR", "Bad request format")
        if how != "call":
            # Explicit check rather than assert, which -O strips.
            return ("ERROR", "Unsupported message type: %r" % (how,))
        if oid not in self.objtable:
            return ("ERROR", "Unknown object id: %r" % (oid,))
        obj = self.objtable[oid]
        if methodname == "__methods__":
            methods = {}
            _getmethods(obj, methods)
            return ("OK", methods)
        if methodname == "__attributes__":
            attributes = {}
            _getattributes(obj, attributes)
            return ("OK", attributes)
        if not hasattr(obj, methodname):
            return ("ERROR", "Unsupported method name: %r" % (methodname,))
        method = getattr(obj, methodname)
        try:
            ret = method(*args, **kwargs)
            if isinstance(ret, RemoteObject):
                ret = remoteref(ret)
            return ("OK", ret)
        except SystemExit:
            raise
        except socket.error:
            # Re-raise.  Falling through would return None, which
            # pollresponse() would send as a (seq, None) response that the
            # peer's decoderesponse() cannot unpack.
            raise
        except:
            self.debug("localcall:EXCEPTION")
            traceback.print_exc(file=sys.__stderr__)
            return ("EXCEPTION", None)

    def remotecall(self, oid, methodname, args, kwargs):
        """Call methodname on remote object oid and wait for the result."""
        self.debug("remotecall:asynccall: ", oid, methodname)
        # XXX KBK 06Feb03 self.interrupted logic may not be necessary if
        # subprocess is threaded.
        if self.interrupted:
            self.interrupted = False
            raise KeyboardInterrupt
        seq = self.asynccall(oid, methodname, args, kwargs)
        return self.asyncreturn(seq)

    def asynccall(self, oid, methodname, args, kwargs):
        """Send a call request without waiting; return its sequence number."""
        request = ("call", (oid, methodname, args, kwargs))
        seq = self.newseq()
        if threading.currentThread() is not self.mainthread:
            # Register interest BEFORE sending.  The main thread may read the
            # reply before this thread reaches _getresponse(), and
            # pollresponse() discards replies for unregistered numbers.
            self.cvar.acquire()
            try:
                self.cvars[seq] = self.cvar
            finally:
                self.cvar.release()
        self.debug(("asynccall:%d:" % seq), oid, methodname, args, kwargs)
        self.putmessage((seq, request))
        return seq

    def asyncreturn(self, seq):
        """Wait for the response to call *seq* and decode it."""
        self.debug("asyncreturn:%d:call getresponse(): " % seq)
        response = self.getresponse(seq, wait=None)
        self.debug(("asyncreturn:%d:response: " % seq), response)
        return self.decoderesponse(response)

    def decoderesponse(self, response):
        """Turn a (how, what) response into a return value or exception.

        "OK" returns what and "EXCEPTION" returns None.  "ERROR" raises
        RuntimeError(what); anything else raises SystemError(how, what).
        """
        how, what = response
        if how == "OK":
            return what
        if how == "EXCEPTION":
            self.debug("decoderesponse: EXCEPTION")
            return None
        if how == "ERROR":
            self.debug("decoderesponse: Internal ERROR:", what)
            raise RuntimeError(what)
        raise SystemError(how, what)

    def mainloop(self):
        """Listen on socket until I/O not ready or EOF

        Main thread pollresponse() will loop looking for seq number None, which
        never comes, and exit on EOFError.

        """
        try:
            self.getresponse(myseq=None, wait=None)
        except EOFError:
            pass

    def getresponse(self, myseq, wait):
        """Return the response for *myseq*, with RemoteProxy values proxified."""
        response = self._getresponse(myseq, wait)
        if response is not None:
            how, what = response
            if how == "OK":
                response = how, self._proxify(what)
        return response

    def _proxify(self, obj):
        """Replace RemoteProxy tokens (top level or in a list) with RPCProxy."""
        if isinstance(obj, RemoteProxy):
            return RPCProxy(self, obj.oid)
        if isinstance(obj, list):
            return [self._proxify(item) for item in obj]
        # XXX Check for other types -- not currently needed
        return obj

    def _getresponse(self, myseq, wait):
        """Block until the response for *myseq* is available.

        The main thread reads the socket itself.  Other threads sleep on
        self.cvar until the main thread files their response.
        """
        self.debug("_getresponse:myseq:", myseq)
        if threading.currentThread() is self.mainthread:
            # Main thread: does all reading of requests or responses
            # Loop here, blocking each time until socket is ready.
            while 1:
                response = self.pollresponse(myseq, wait)
                if response is not None:
                    return response
        else:
            # Auxiliary thread: wait for notification from main thread
            self.cvar.acquire()
            try:
                self.cvars[myseq] = self.cvar
                while myseq not in self.responses:
                    self.cvar.wait()
                response = self.responses[myseq]
                del self.responses[myseq]
                del self.cvars[myseq]
            finally:
                self.cvar.release()
            return response

    def newseq(self):
        """Return the next sequence number (steps by 2; see RPCClient)."""
        self.nextseq = seq = self.nextseq + 2
        return seq

    def putmessage(self, message):
        """Pickle *message*, frame it with a length prefix and send it all.

        Raises IOError if the socket has been closed.
        NOTE(review): sends are not serialized across threads; confirm that
        callers never send from several threads at once.
        """
        self.debug("putmessage:%d:" % message[0])
        try:
            s = pickle.dumps(message)
        except:
            print >>sys.__stderr__, "Cannot pickle:", repr(message)
            raise
        s = struct.pack("<i", len(s)) + s
        while len(s) > 0:
            try:
                n = self.sock.send(s)
            except AttributeError:
                # socket was closed (self.sock is None)
                raise IOError("socket no longer exists")
            else:
                s = s[n:]

    def ioready(self, wait=0.0):
        """Return nonzero if the socket is readable within *wait* seconds."""
        r, w, x = select.select([self.sock.fileno()], [], [], wait)
        return len(r)

    buffer = ""
    bufneed = 4
    bufstate = 0 # meaning: 0 => reading count; 1 => reading data

    def pollpacket(self, wait=0.0):
        """Return one complete packet, or None if none is available yet.

        Raises EOFError when the peer closes the socket or recv fails.
        """
        self._stage0()
        if len(self.buffer) < self.bufneed:
            if not self.ioready(wait):
                return None
            try:
                s = self.sock.recv(BUFSIZE)
            except socket.error:
                raise EOFError
            if len(s) == 0:
                raise EOFError
            self.buffer += s
            self._stage0()
        return self._stage1()

    def _stage0(self):
        """Consume the 4-byte length prefix once enough bytes are buffered."""
        if self.bufstate == 0 and len(self.buffer) >= 4:
            s = self.buffer[:4]
            self.buffer = self.buffer[4:]
            self.bufneed = struct.unpack("<i", s)[0]
            self.bufstate = 1

    def _stage1(self):
        """Return the packet body once it is fully buffered, else None."""
        if self.bufstate == 1 and len(self.buffer) >= self.bufneed:
            packet = self.buffer[:self.bufneed]
            self.buffer = self.buffer[self.bufneed:]
            self.bufneed = 4
            self.bufstate = 0
            return packet

    def pollmessage(self, wait=0.0):
        """Return the next unpickled message, or None if none is ready."""
        packet = self.pollpacket(wait)
        if packet is None:
            return None
        try:
            message = pickle.loads(packet)
        except:
            print >>sys.__stderr__, "-----------------------"
            print >>sys.__stderr__, "cannot unpickle packet:", repr(packet)
            traceback.print_stack(file=sys.__stderr__)
            print >>sys.__stderr__, "-----------------------"
            raise
        return message

    def pollresponse(self, myseq, wait=0.0):
        """Handle messages received on the socket.

        Some messages received may be asynchronous 'call' commands, and
        some may be responses intended for other threads.

        Loop until message with myseq sequence number is received. Save others
        in self.responses and notify the owning thread, except that 'call'
        commands are handed off to localcall() and the response sent back
        across the link with the appropriate sequence number.

        """
        while 1:
            message = self.pollmessage(wait)
            if message is None:  # socket not ready
                return None
            seq, resq = message
            self.debug("pollresponse:%d:myseq:%s" % (seq, myseq))
            if resq[0] == "call":
                self.debug("pollresponse:%d:localcall:call:" % seq)
                response = self.localcall(resq)
                self.debug("pollresponse:%d:localcall:response:%s"
                           % (seq, response))
                self.putmessage((seq, response))
                continue
            elif seq == myseq:
                return resq
            else:
                self.cvar.acquire()
                try:
                    cv = self.cvars.get(seq)
                    # response involving unknown sequence number is discarded,
                    # probably intended for prior incarnation
                    if cv is not None:
                        self.responses[seq] = resq
                        # All waiting threads share this condition, so wake
                        # them all; each re-checks for its own sequence number.
                        cv.notifyAll()
                finally:
                    self.cvar.release()
                continue

#----------------- end class SocketIO --------------------
391
class RemoteObject:
    """Token mix-in: instances returned from a served method are passed by
    reference (registered and sent as a RemoteProxy) rather than by value."""
395
def remoteref(obj):
    """Register *obj* in the module-level objecttable; return a token for it.

    The object is keyed by id(obj), and the table keeps a reference to it.
    NOTE(review): nothing here ever removes the entry — confirm whether
    callers unregister it.
    """
    key = id(obj)
    objecttable[key] = obj
    return RemoteProxy(key)
400
class RemoteProxy:
    """Picklable stand-in sent over the wire in place of a registered object.

    Carries only the object's id; the receiving end turns it into an
    RPCProxy.
    """

    def __init__(self, oid):
        # Id under which the real object sits in the sender's object table.
        self.oid = oid
405
class RPCHandler(SocketServer.BaseRequestHandler, SocketIO):
    """Server-side request handler: a SocketIO bound to the server's socket.

    RPCServer uses this as its default handler class.  handle() services
    RPC traffic until the connection reaches EOF.
    """

    debugging = False
    location = "#S"  # Server tag used in debug() output

    def __init__(self, sock, addr, svr):
        # Publish this handler as svr.current_handler, so served objects can
        # reach back to the client through get_remote_proxy().  ## cgt xxx
        svr.current_handler = self
        # Set up the SocketIO side first.  The base handler constructor
        # presumably dispatches to handle() straight away (SocketServer
        # convention), so the RPC state must already exist.
        SocketIO.__init__(self, sock)
        SocketServer.BaseRequestHandler.__init__(self, sock, addr, svr)

    def handle(self):
        """Serve RPC traffic until EOF (hook required by SocketServer)."""
        self.mainloop()

    def get_remote_proxy(self, oid):
        """Return an RPCProxy for the client-side object registered as *oid*."""
        return RPCProxy(self, oid)
422
class RPCClient(SocketIO):
    """Idle-side endpoint: listens for the execution server to connect back.

    The constructor only binds and listens.  SocketIO state is initialised
    by accept() once the server has connected.
    """

    debugging = False
    location = "#C" # Client

    nextseq = 1 # Requests coming from the client are odd numbered

    def __init__(self, address, family=socket.AF_INET, type=socket.SOCK_STREAM):
        sock = socket.socket(family, type)
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind(address)
            sock.listen(1)
        except:
            # Don't leak the descriptor if setup fails; re-raise unchanged.
            sock.close()
            raise
        self.listening_sock = sock

    def accept(self):
        """Accept one connection and start RPC on it.

        Only connections from 127.0.0.1 are accepted.  Any other peer is
        reported, its socket is closed, and socket.error is raised.
        The listening socket stays open.
        """
        working_sock, address = self.listening_sock.accept()
        if self.debugging:
            print>>sys.__stderr__, "****** Connection request from ", address
        if address[0] == '127.0.0.1':
            SocketIO.__init__(self, working_sock)
        else:
            print>>sys.__stderr__, "** Invalid host: ", address
            # Close the rejected connection instead of leaking it.
            working_sock.close()
            raise socket.error

    def get_remote_proxy(self, oid):
        """Return an RPCProxy for the server-side object registered as *oid*."""
        return RPCProxy(self, oid)
449
class RPCProxy:
    """Local stand-in for an object registered on the other end of *sockio*.

    Attribute lookups are resolved lazily.  The first lookup fetches the
    remote object's callable names ("__methods__") and, if needed, its
    non-callable names ("__attributes__").  Method names yield a MethodProxy.
    Data attributes are read with a remote ``__getattribute__`` call.
    """

    __methods = None
    __attributes = None

    def __init__(self, sockio, oid):
        self.sockio = sockio
        self.oid = oid

    def __getattr__(self, name):
        if self.__methods is None:
            self.__getmethods()
        if self.__methods.get(name):
            return MethodProxy(self.sockio, self.oid, name)
        if self.__attributes is None:
            self.__getattributes()
        if name not in self.__attributes:
            raise AttributeError(name)
        # Fetch the attribute's current value.  Without this, lookups of
        # existing remote data attributes silently returned None.
        # NOTE(review): the remote object must expose __getattribute__
        # (presumably only new-style objects do); otherwise localcall()
        # answers "ERROR" and remotecall raises RuntimeError.
        return self.sockio.remotecall(self.oid, "__getattribute__",
                                      (name,), {})
    __getattr__.DebuggerStepThrough=1

    def __getattributes(self):
        self.__attributes = self.sockio.remotecall(self.oid,
                                                   "__attributes__", (), {})

    def __getmethods(self):
        self.__methods = self.sockio.remotecall(self.oid,
                                                "__methods__", (), {})
477
478def _getmethods(obj, methods):
479 # Helper to get a list of methods from an object
480 # Adds names to dictionary argument 'methods'
481 for name in dir(obj):
482 attr = getattr(obj, name)
483 if callable(attr):
484 methods[name] = 1
485 if type(obj) == types.InstanceType:
486 _getmethods(obj.__class__, methods)
487 if type(obj) == types.ClassType:
488 for super in obj.__bases__:
489 _getmethods(super, methods)
490
491def _getattributes(obj, attributes):
492 for name in dir(obj):
493 attr = getattr(obj, name)
494 if not callable(attr):
495 attributes[name] = 1
496
class MethodProxy:
    """Callable that forwards invocations of one method of a remote object."""

    def __init__(self, sockio, oid, name):
        self.sockio, self.oid, self.name = sockio, oid, name

    def __call__(self, *args, **kwargs):
        # Positional args go as a tuple and keyword args as a dict, exactly as
        # the caller supplied them.
        return self.sockio.remotecall(self.oid, self.name, args, kwargs)
507
508#
509# Self Test
510#
511
512def testServer(addr):
513 # XXX 25 Jul 02 KBK needs update to use rpc.py register/unregister methods
514 class RemotePerson:
515 def __init__(self,name):
516 self.name = name
517 def greet(self, name):
518 print "(someone called greet)"
519 print "Hello %s, I am %s." % (name, self.name)
520 print
521 def getName(self):
522 print "(someone called getName)"
523 print
524 return self.name
525 def greet_this_guy(self, name):
526 print "(someone called greet_this_guy)"
527 print "About to greet %s ..." % name
528 remote_guy = self.server.current_handler.get_remote_proxy(name)
529 remote_guy.greet("Thomas Edison")
530 print "Done."
531 print
532
533 person = RemotePerson("Thomas Edison")
534 svr = RPCServer(addr)
535 svr.register('thomas', person)
536 person.server = svr # only required if callbacks are used
537
538 # svr.serve_forever()
539 svr.handle_request() # process once only
540
541def testClient(addr):
542 "demonstrates RPC Client"
543 # XXX 25 Jul 02 KBK needs update to use rpc.py register/unregister methods
544 import time
545 clt=RPCClient(addr)
546 thomas = clt.get_remote_proxy("thomas")
547 print "The remote person's name is ..."
548 print thomas.getName()
549 # print clt.remotecall("thomas", "getName", (), {})
550 print
551 time.sleep(1)
552 print "Getting remote thomas to say hi..."
553 thomas.greet("Alexander Bell")
554 #clt.remotecall("thomas","greet",("Alexander Bell",), {})
555 print "Done."
556 print
557 time.sleep(2)
558 # demonstrates remote server calling local instance
559 class LocalPerson:
560 def __init__(self,name):
561 self.name = name
562 def greet(self, name):
563 print "You've greeted me!"
564 def getName(self):
565 return self.name
566 person = LocalPerson("Alexander Bell")
567 clt.register("alexander",person)
568 thomas.greet_this_guy("alexander")
569 # clt.remotecall("thomas","greet_this_guy",("alexander",), {})
570
def test():
    """Run the self test on localhost:8833.

    With exactly one command-line argument equal to '-server', run the server
    side; otherwise run the client side.
    """
    addr = ("localhost", 8833)
    if len(sys.argv) == 2 and sys.argv[1] == '-server':
        testServer(addr)
    else:
        testClient(addr)
578
if __name__ == "__main__":
    # Executed as a script: run the self test (see test()).
    test()