// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package rpc
6
7import (
8 "errors"
9 "fmt"
10 "io"
11 "log"
12 "net"
13 "net/http/httptest"
14 "runtime"
15 "strings"
16 "sync"
17 "sync/atomic"
18 "testing"
19 "time"
20)
21
22var (
23 newServer *Server
24 serverAddr, newServerAddr string
25 httpServerAddr string
26 once, newOnce, httpOnce sync.Once
27)
28
29const (
30 newHttpPath = "/foo"
31)
32
// Args holds the two integer operands passed to the Arith methods.
type Args struct {
	A, B int
}

// Reply carries the integer result written by an Arith method.
type Reply struct {
	C int
}

// Arith is the RPC service exercised by the tests.
//
// Some of Arith's methods have value args, some have pointer args. That's
// deliberate (presumably so both argument forms go through the server).
type Arith int

// Add stores A+B in reply. Its argument is taken by value.
func (*Arith) Add(args Args, reply *Reply) error {
	sum := args.A + args.B
	reply.C = sum
	return nil
}

// Mul stores A*B in reply. Its argument is taken by pointer.
func (*Arith) Mul(args *Args, reply *Reply) error {
	product := args.A * args.B
	reply.C = product
	return nil
}

// Div stores A/B in reply, or fails with "divide by zero" when B is 0
// (leaving reply untouched).
func (*Arith) Div(args Args, reply *Reply) error {
	if args.B != 0 {
		reply.C = args.A / args.B
		return nil
	}
	return errors.New("divide by zero")
}

// String formats the sum as "A+B=C" into a non-struct reply.
func (*Arith) String(args *Args, reply *string) error {
	a, b := args.A, args.B
	*reply = fmt.Sprintf("%d+%d=%d", a, b, a+b)
	return nil
}

// Scan parses a non-struct (string) argument as an integer into reply.C and
// returns any fmt.Sscan error.
func (*Arith) Scan(args string, reply *Reply) (err error) {
	_, err = fmt.Sscan(args, &reply.C)
	return err
}

// Error always panics with "ERROR".
func (*Arith) Error(args *Args, reply *Reply) error {
	panic("ERROR")
}
76
// listenTCP opens a TCP listener on a kernel-chosen port of 127.0.0.1 and
// returns it together with its "host:port" address. It exits the process
// via log.Fatalf if the listen fails.
func listenTCP() (net.Listener, string) {
	const anyPort = "127.0.0.1:0" // port 0: any available address
	ln, err := net.Listen("tcp", anyPort)
	if err != nil {
		log.Fatalf("net.Listen tcp :0: %v", err)
	}
	addr := ln.Addr().String()
	return ln, addr
}
84
85func startServer() {
86 Register(new(Arith))
87 RegisterName("net.rpc.Arith", new(Arith))
88
89 var l net.Listener
90 l, serverAddr = listenTCP()
91 log.Println("Test RPC server listening on", serverAddr)
92 go Accept(l)
93
94 HandleHTTP()
95 httpOnce.Do(startHttpServer)
96}
97
98func startNewServer() {
99 newServer = NewServer()
100 newServer.Register(new(Arith))
101 newServer.RegisterName("net.rpc.Arith", new(Arith))
102 newServer.RegisterName("newServer.Arith", new(Arith))
103
104 var l net.Listener
105 l, newServerAddr = listenTCP()
106 log.Println("NewServer test RPC server listening on", newServerAddr)
107 go newServer.Accept(l)
108
109 newServer.HandleHTTP(newHttpPath, "/bar")
110 httpOnce.Do(startHttpServer)
111}
112
113func startHttpServer() {
114 server := httptest.NewServer(nil)
115 httpServerAddr = server.Listener.Addr().String()
116 log.Println("Test HTTP RPC server listening on", httpServerAddr)
117}
118
119func TestRPC(t *testing.T) {
120 once.Do(startServer)
121 testRPC(t, serverAddr)
122 newOnce.Do(startNewServer)
123 testRPC(t, newServerAddr)
124 testNewServerRPC(t, newServerAddr)
125}
126
127func testRPC(t *testing.T, addr string) {
128 client, err := Dial("tcp", addr)
129 if err != nil {
130 t.Fatal("dialing", err)
131 }
132 defer client.Close()
133
134 // Synchronous calls
135 args := &Args{7, 8}
136 reply := new(Reply)
137 err = client.Call("Arith.Add", args, reply)
138 if err != nil {
139 t.Errorf("Add: expected no error but got string %q", err.Error())
140 }
141 if reply.C != args.A+args.B {
142 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
143 }
144
145 // Nonexistent method
146 args = &Args{7, 0}
147 reply = new(Reply)
148 err = client.Call("Arith.BadOperation", args, reply)
149 // expect an error
150 if err == nil {
151 t.Error("BadOperation: expected error")
152 } else if !strings.HasPrefix(err.Error(), "rpc: can't find method ") {
153 t.Errorf("BadOperation: expected can't find method error; got %q", err)
154 }
155
156 // Unknown service
157 args = &Args{7, 8}
158 reply = new(Reply)
159 err = client.Call("Arith.Unknown", args, reply)
160 if err == nil {
161 t.Error("expected error calling unknown service")
162 } else if strings.Index(err.Error(), "method") < 0 {
163 t.Error("expected error about method; got", err)
164 }
165
166 // Out of order.
167 args = &Args{7, 8}
168 mulReply := new(Reply)
169 mulCall := client.Go("Arith.Mul", args, mulReply, nil)
170 addReply := new(Reply)
171 addCall := client.Go("Arith.Add", args, addReply, nil)
172
173 addCall = <-addCall.Done
174 if addCall.Error != nil {
175 t.Errorf("Add: expected no error but got string %q", addCall.Error.Error())
176 }
177 if addReply.C != args.A+args.B {
178 t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B)
179 }
180
181 mulCall = <-mulCall.Done
182 if mulCall.Error != nil {
183 t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error())
184 }
185 if mulReply.C != args.A*args.B {
186 t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B)
187 }
188
189 // Error test
190 args = &Args{7, 0}
191 reply = new(Reply)
192 err = client.Call("Arith.Div", args, reply)
193 // expect an error: zero divide
194 if err == nil {
195 t.Error("Div: expected error")
196 } else if err.Error() != "divide by zero" {
197 t.Error("Div: expected divide by zero error; got", err)
198 }
199
200 // Bad type.
201 reply = new(Reply)
202 err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use
203 if err == nil {
204 t.Error("expected error calling Arith.Add with wrong arg type")
205 } else if strings.Index(err.Error(), "type") < 0 {
206 t.Error("expected error about type; got", err)
207 }
208
209 // Non-struct argument
210 const Val = 12345
211 str := fmt.Sprint(Val)
212 reply = new(Reply)
213 err = client.Call("Arith.Scan", &str, reply)
214 if err != nil {
215 t.Errorf("Scan: expected no error but got string %q", err.Error())
216 } else if reply.C != Val {
217 t.Errorf("Scan: expected %d got %d", Val, reply.C)
218 }
219
220 // Non-struct reply
221 args = &Args{27, 35}
222 str = ""
223 err = client.Call("Arith.String", args, &str)
224 if err != nil {
225 t.Errorf("String: expected no error but got string %q", err.Error())
226 }
227 expect := fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B)
228 if str != expect {
229 t.Errorf("String: expected %s got %s", expect, str)
230 }
231
232 args = &Args{7, 8}
233 reply = new(Reply)
234 err = client.Call("Arith.Mul", args, reply)
235 if err != nil {
236 t.Errorf("Mul: expected no error but got string %q", err.Error())
237 }
238 if reply.C != args.A*args.B {
239 t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
240 }
241
242 // ServiceName contain "." character
243 args = &Args{7, 8}
244 reply = new(Reply)
245 err = client.Call("net.rpc.Arith.Add", args, reply)
246 if err != nil {
247 t.Errorf("Add: expected no error but got string %q", err.Error())
248 }
249 if reply.C != args.A+args.B {
250 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
251 }
252}
253
254func testNewServerRPC(t *testing.T, addr string) {
255 client, err := Dial("tcp", addr)
256 if err != nil {
257 t.Fatal("dialing", err)
258 }
259 defer client.Close()
260
261 // Synchronous calls
262 args := &Args{7, 8}
263 reply := new(Reply)
264 err = client.Call("newServer.Arith.Add", args, reply)
265 if err != nil {
266 t.Errorf("Add: expected no error but got string %q", err.Error())
267 }
268 if reply.C != args.A+args.B {
269 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
270 }
271}
272
273func TestHTTP(t *testing.T) {
274 once.Do(startServer)
275 testHTTPRPC(t, "")
276 newOnce.Do(startNewServer)
277 testHTTPRPC(t, newHttpPath)
278}
279
280func testHTTPRPC(t *testing.T, path string) {
281 var client *Client
282 var err error
283 if path == "" {
284 client, err = DialHTTP("tcp", httpServerAddr)
285 } else {
286 client, err = DialHTTPPath("tcp", httpServerAddr, path)
287 }
288 if err != nil {
289 t.Fatal("dialing", err)
290 }
291 defer client.Close()
292
293 // Synchronous calls
294 args := &Args{7, 8}
295 reply := new(Reply)
296 err = client.Call("Arith.Add", args, reply)
297 if err != nil {
298 t.Errorf("Add: expected no error but got string %q", err.Error())
299 }
300 if reply.C != args.A+args.B {
301 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
302 }
303}
304
305// CodecEmulator provides a client-like api and a ServerCodec interface.
306// Can be used to test ServeRequest.
307type CodecEmulator struct {
308 server *Server
309 serviceMethod string
310 args *Args
311 reply *Reply
312 err error
313}
314
315func (codec *CodecEmulator) Call(serviceMethod string, args *Args, reply *Reply) error {
316 codec.serviceMethod = serviceMethod
317 codec.args = args
318 codec.reply = reply
319 codec.err = nil
320 var serverError error
321 if codec.server == nil {
322 serverError = ServeRequest(codec)
323 } else {
324 serverError = codec.server.ServeRequest(codec)
325 }
326 if codec.err == nil && serverError != nil {
327 codec.err = serverError
328 }
329 return codec.err
330}
331
332func (codec *CodecEmulator) ReadRequestHeader(req *Request) error {
333 req.ServiceMethod = codec.serviceMethod
334 req.Seq = 0
335 return nil
336}
337
338func (codec *CodecEmulator) ReadRequestBody(argv interface{}) error {
339 if codec.args == nil {
340 return io.ErrUnexpectedEOF
341 }
342 *(argv.(*Args)) = *codec.args
343 return nil
344}
345
346func (codec *CodecEmulator) WriteResponse(resp *Response, reply interface{}) error {
347 if resp.Error != "" {
348 codec.err = errors.New(resp.Error)
349 } else {
350 *codec.reply = *(reply.(*Reply))
351 }
352 return nil
353}
354
355func (codec *CodecEmulator) Close() error {
356 return nil
357}
358
359func TestServeRequest(t *testing.T) {
360 once.Do(startServer)
361 testServeRequest(t, nil)
362 newOnce.Do(startNewServer)
363 testServeRequest(t, newServer)
364}
365
366func testServeRequest(t *testing.T, server *Server) {
367 client := CodecEmulator{server: server}
368 defer client.Close()
369
370 args := &Args{7, 8}
371 reply := new(Reply)
372 err := client.Call("Arith.Add", args, reply)
373 if err != nil {
374 t.Errorf("Add: expected no error but got string %q", err.Error())
375 }
376 if reply.C != args.A+args.B {
377 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
378 }
379
380 err = client.Call("Arith.Add", nil, reply)
381 if err == nil {
382 t.Errorf("expected error calling Arith.Add with nil arg")
383 }
384}
385
386type ReplyNotPointer int
387type ArgNotPublic int
388type ReplyNotPublic int
389type NeedsPtrType int
390type local struct{}
391
392func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) error {
393 return nil
394}
395
396func (t *ArgNotPublic) ArgNotPublic(args *local, reply *Reply) error {
397 return nil
398}
399
400func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) error {
401 return nil
402}
403
404func (t *NeedsPtrType) NeedsPtrType(args *Args, reply *Reply) error {
405 return nil
406}
407
408// Check that registration handles lots of bad methods and a type with no suitable methods.
409func TestRegistrationError(t *testing.T) {
410 err := Register(new(ReplyNotPointer))
411 if err == nil {
412 t.Error("expected error registering ReplyNotPointer")
413 }
414 err = Register(new(ArgNotPublic))
415 if err == nil {
416 t.Error("expected error registering ArgNotPublic")
417 }
418 err = Register(new(ReplyNotPublic))
419 if err == nil {
420 t.Error("expected error registering ReplyNotPublic")
421 }
422 err = Register(NeedsPtrType(0))
423 if err == nil {
424 t.Error("expected error registering NeedsPtrType")
425 } else if !strings.Contains(err.Error(), "pointer") {
426 t.Error("expected hint when registering NeedsPtrType")
427 }
428}
429
430type WriteFailCodec int
431
432func (WriteFailCodec) WriteRequest(*Request, interface{}) error {
433 // the panic caused by this error used to not unlock a lock.
434 return errors.New("fail")
435}
436
437func (WriteFailCodec) ReadResponseHeader(*Response) error {
438 select {}
439}
440
441func (WriteFailCodec) ReadResponseBody(interface{}) error {
442 select {}
443}
444
445func (WriteFailCodec) Close() error {
446 return nil
447}
448
449func TestSendDeadlock(t *testing.T) {
450 client := NewClientWithCodec(WriteFailCodec(0))
451 defer client.Close()
452
453 done := make(chan bool)
454 go func() {
455 testSendDeadlock(client)
456 testSendDeadlock(client)
457 done <- true
458 }()
459 select {
460 case <-done:
461 return
462 case <-time.After(5 * time.Second):
463 t.Fatal("deadlock")
464 }
465}
466
467func testSendDeadlock(client *Client) {
468 defer func() {
469 recover()
470 }()
471 args := &Args{7, 8}
472 reply := new(Reply)
473 client.Call("Arith.Add", args, reply)
474}
475
476func dialDirect() (*Client, error) {
477 return Dial("tcp", serverAddr)
478}
479
480func dialHTTP() (*Client, error) {
481 return DialHTTP("tcp", httpServerAddr)
482}
483
484func countMallocs(dial func() (*Client, error), t *testing.T) float64 {
485 once.Do(startServer)
486 client, err := dial()
487 if err != nil {
488 t.Fatal("error dialing", err)
489 }
490 defer client.Close()
491
492 args := &Args{7, 8}
493 reply := new(Reply)
494 return testing.AllocsPerRun(100, func() {
495 err := client.Call("Arith.Add", args, reply)
496 if err != nil {
497 t.Errorf("Add: expected no error but got string %q", err.Error())
498 }
499 if reply.C != args.A+args.B {
500 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
501 }
502 })
503}
504
505func TestCountMallocs(t *testing.T) {
506 if testing.Short() {
507 t.Skip("skipping malloc count in short mode")
508 }
509 if runtime.GOMAXPROCS(0) > 1 {
510 t.Skip("skipping; GOMAXPROCS>1")
511 }
512 fmt.Printf("mallocs per rpc round trip: %v\n", countMallocs(dialDirect, t))
513}
514
515func TestCountMallocsOverHTTP(t *testing.T) {
516 if testing.Short() {
517 t.Skip("skipping malloc count in short mode")
518 }
519 if runtime.GOMAXPROCS(0) > 1 {
520 t.Skip("skipping; GOMAXPROCS>1")
521 }
522 fmt.Printf("mallocs per HTTP rpc round trip: %v\n", countMallocs(dialHTTP, t))
523}
524
// writeCrasher is a connection stand-in (passed to NewClient) whose writes
// always fail. Read blocks until a value is received from done, or done is
// closed, and then reports io.EOF.
//
// All methods use pointer receivers. The type was previously mixing value and
// pointer receivers, and its only use is as *writeCrasher.
type writeCrasher struct {
	done chan bool
}

// Close is a no-op.
func (w *writeCrasher) Close() error {
	return nil
}

// Read waits on done, then reports end of stream without reading anything.
func (w *writeCrasher) Read(p []byte) (int, error) {
	<-w.done
	return 0, io.EOF
}

// Write always fails with "fake write failure".
func (w *writeCrasher) Write(p []byte) (int, error) {
	return 0, errors.New("fake write failure")
}
541
542func TestClientWriteError(t *testing.T) {
543 w := &writeCrasher{done: make(chan bool)}
544 c := NewClient(w)
545 defer c.Close()
546
547 res := false
548 err := c.Call("foo", 1, &res)
549 if err == nil {
550 t.Fatal("expected error")
551 }
552 if err.Error() != "fake write failure" {
553 t.Error("unexpected value of error:", err)
554 }
555 w.done <- true
556}
557
558func TestTCPClose(t *testing.T) {
559 once.Do(startServer)
560
561 client, err := dialHTTP()
562 if err != nil {
563 t.Fatalf("dialing: %v", err)
564 }
565 defer client.Close()
566
567 args := Args{17, 8}
568 var reply Reply
569 err = client.Call("Arith.Mul", args, &reply)
570 if err != nil {
571 t.Fatal("arith error:", err)
572 }
573 t.Logf("Arith: %d*%d=%d\n", args.A, args.B, reply)
574 if reply.C != args.A*args.B {
575 t.Errorf("Add: expected %d got %d", reply.C, args.A*args.B)
576 }
577}
578
579func TestErrorAfterClientClose(t *testing.T) {
580 once.Do(startServer)
581
582 client, err := dialHTTP()
583 if err != nil {
584 t.Fatalf("dialing: %v", err)
585 }
586 err = client.Close()
587 if err != nil {
588 t.Fatal("close error:", err)
589 }
590 err = client.Call("Arith.Add", &Args{7, 9}, new(Reply))
591 if err != ErrShutdown {
592 t.Errorf("Forever: expected ErrShutdown got %v", err)
593 }
594}
595
596func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) {
597 once.Do(startServer)
598 client, err := dial()
599 if err != nil {
600 b.Fatal("error dialing:", err)
601 }
602 defer client.Close()
603
604 // Synchronous calls
605 args := &Args{7, 8}
606 b.ResetTimer()
607
608 b.RunParallel(func(pb *testing.PB) {
609 reply := new(Reply)
610 for pb.Next() {
611 err := client.Call("Arith.Add", args, reply)
612 if err != nil {
613 b.Fatalf("rpc error: Add: expected no error but got string %q", err.Error())
614 }
615 if reply.C != args.A+args.B {
616 b.Fatalf("rpc error: Add: expected %d got %d", reply.C, args.A+args.B)
617 }
618 }
619 })
620}
621
622func benchmarkEndToEndAsync(dial func() (*Client, error), b *testing.B) {
623 const MaxConcurrentCalls = 100
624 once.Do(startServer)
625 client, err := dial()
626 if err != nil {
627 b.Fatal("error dialing:", err)
628 }
629 defer client.Close()
630
631 // Asynchronous calls
632 args := &Args{7, 8}
633 procs := 4 * runtime.GOMAXPROCS(-1)
634 send := int32(b.N)
635 recv := int32(b.N)
636 var wg sync.WaitGroup
637 wg.Add(procs)
638 gate := make(chan bool, MaxConcurrentCalls)
639 res := make(chan *Call, MaxConcurrentCalls)
640 b.ResetTimer()
641
642 for p := 0; p < procs; p++ {
643 go func() {
644 for atomic.AddInt32(&send, -1) >= 0 {
645 gate <- true
646 reply := new(Reply)
647 client.Go("Arith.Add", args, reply, res)
648 }
649 }()
650 go func() {
651 for call := range res {
652 A := call.Args.(*Args).A
653 B := call.Args.(*Args).B
654 C := call.Reply.(*Reply).C
655 if A+B != C {
656 b.Fatalf("incorrect reply: Add: expected %d got %d", A+B, C)
657 }
658 <-gate
659 if atomic.AddInt32(&recv, -1) == 0 {
660 close(res)
661 }
662 }
663 wg.Done()
664 }()
665 }
666 wg.Wait()
667}
668
669func BenchmarkEndToEnd(b *testing.B) {
670 benchmarkEndToEnd(dialDirect, b)
671}
672
673func BenchmarkEndToEndHTTP(b *testing.B) {
674 benchmarkEndToEnd(dialHTTP, b)
675}
676
677func BenchmarkEndToEndAsync(b *testing.B) {
678 benchmarkEndToEndAsync(dialDirect, b)
679}
680
681func BenchmarkEndToEndAsyncHTTP(b *testing.B) {
682 benchmarkEndToEndAsync(dialHTTP, b)
683}