blob: 257468e52d72854667e6c848060e1f68e6a7b25e [file] [log] [blame]
temporal40ee5512008-07-10 02:12:20 +00001#!/usr/bin/python2.4
2#
3# Copyright 2008 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17# This file is used for testing. The original is at:
18# http://code.google.com/p/pymox/
19
20"""Mox, an object-mocking framework for Python.
21
22Mox works in the record-replay-verify paradigm. When you first create
23a mock object, it is in record mode. You then programmatically set
24the expected behavior of the mock object (what methods are to be
25called on it, with what parameters, what they should return, and in
26what order).
27
28Once you have set up the expected mock behavior, you put it in replay
29mode. Now the mock responds to method calls just as you told it to.
30If an unexpected method (or an expected method with unexpected
31parameters) is called, then an exception will be raised.
32
33Once you are done interacting with the mock, you need to verify that
all the expected interactions occurred. (Maybe your code exited
prematurely without calling some cleanup method!) The verify phase
36ensures that every expected method was called; otherwise, an exception
37will be raised.
38
39Suggested usage / workflow:
40
41 # Create Mox factory
42 my_mox = Mox()
43
44 # Create a mock data access object
45 mock_dao = my_mox.CreateMock(DAOClass)
46
47 # Set up expected behavior
48 mock_dao.RetrievePersonWithIdentifier('1').AndReturn(person)
49 mock_dao.DeletePerson(person)
50
51 # Put mocks in replay mode
52 my_mox.ReplayAll()
53
54 # Inject mock object and run test
55 controller.SetDao(mock_dao)
56 controller.DeletePersonById('1')
57
58 # Verify all methods were called as expected
59 my_mox.VerifyAll()
60"""
61
62from collections import deque
63import re
64import types
65import unittest
66
67import stubout
68
class Error(AssertionError):
  """Common ancestor of every exception defined in this module.

  It is itself an AssertionError, so callers can catch either type.
  """
73
74
class ExpectedMethodCallsError(Error):
  """Signals that verification ran while expected calls were still pending."""

  def __init__(self, expected_methods):
    """Remember the expectations that were never met.

    Args:
      expected_methods: non-empty sequence of the MockMethod objects that
        should have been called but were not.

    Raises:
      ValueError: if expected_methods is empty.
    """

    if expected_methods:
      Error.__init__(self)
      self._expected_methods = expected_methods
    else:
      raise ValueError("There must be at least one expected method")

  def __str__(self):
    """Return a numbered, one-per-line listing of the missing calls."""

    numbered = []
    for position, pending in enumerate(self._expected_methods):
      numbered.append("%3d. %s" % (position, pending))
    listing = "\n".join(numbered)
    return "Verify: Expected methods never called:\n%s" % (listing,)
100
101
class UnexpectedMethodCallError(Error):
  """Signals a call that does not match what the mock expected next.

  Typical causes are wrong parameters or a call made out of the recorded
  order.
  """

  def __init__(self, unexpected_method, expected):
    """Remember the offending call and what was expected instead.

    Args:
      unexpected_method: the MockMethod that was called but did not match
        the head of the expected-call queue.
      expected: the MockMethod or method group that was expected at that
        point.
    """

    Error.__init__(self)
    self._expected = expected
    self._unexpected_method = unexpected_method

  def __str__(self):
    """Describe both the actual call and the expected one."""

    details = (self._unexpected_method, self._expected)
    return "Unexpected method call: %s. Expecting: %s" % details
128
129
class UnknownMethodCallError(Error):
  """Signals a request for a method the mocked object does not have."""

  def __init__(self, unknown_method_name):
    """Remember the name that was requested.

    Args:
      unknown_method_name: str, name of the attribute that is not part of
        the mocked class's interface.
    """

    Error.__init__(self)
    self._unknown_method_name = unknown_method_name

  def __str__(self):
    """Return a message naming the unknown method."""

    template = "Method called is not a member of the object: %s"
    return template % self._unknown_method_name
148
149
class Mox(object):
  """Factory that creates mocks and drives them as a group.

  Every mock created through this object is remembered, so ReplayAll,
  VerifyAll and ResetAll can act on all of them at once. Attribute stubs
  installed with StubOutWithMock are tracked separately in self.stubs.
  """

  # Attribute types that StubOutWithMock replaces with a MockObject; any
  # other attribute type is replaced with a MockAnything.
  _USE_MOCK_OBJECT = [types.ClassType, types.InstanceType, types.ModuleType,
                      types.ObjectType, types.TypeType]

  def __init__(self):
    """Start with no mocks and a fresh stub manager."""

    self._mock_objects = []
    self.stubs = stubout.StubOutForTesting()

  def CreateMock(self, class_to_mock):
    """Create and register a MockObject for class_to_mock.

    Args:
      class_to_mock: the class whose interface the mock should enforce.

    Returns:
      The new MockObject.
    """

    mock = MockObject(class_to_mock)
    self._mock_objects.append(mock)
    return mock

  def CreateMockAnything(self):
    """Create and register a MockAnything.

    The returned mock accepts any method call; no interface is enforced.
    """

    mock = MockAnything()
    self._mock_objects.append(mock)
    return mock

  def ReplayAll(self):
    """Switch every registered mock to replay mode."""

    for registered in self._mock_objects:
      registered._Replay()

  def VerifyAll(self):
    """Verify every registered mock."""

    for registered in self._mock_objects:
      registered._Verify()

  def ResetAll(self):
    """Reset every registered mock. Installed stubs are left in place."""

    for registered in self._mock_objects:
      registered._Reset()

  def StubOutWithMock(self, obj, attr_name, use_mock_anything=False):
    """Replace obj.attr_name with a mock created by this factory.

    The current attribute decides the kind of mock: when its type is listed
    in _USE_MOCK_OBJECT a MockObject wrapping it is used, otherwise a
    MockAnything. Passing use_mock_anything=True always selects a
    MockAnything.

    Args:
      obj: the object (class, module, instance, ...) owning the attribute.
      attr_name: str, name of the attribute to replace.
      use_mock_anything: bool, force a MockAnything regardless of the
        attribute's type.
    """

    target = getattr(obj, attr_name)
    if type(target) in self._USE_MOCK_OBJECT and not use_mock_anything:
      replacement = self.CreateMock(target)
    else:
      replacement = self.CreateMockAnything()

    self.stubs.Set(obj, attr_name, replacement)

  def UnsetStubs(self):
    """Undo every stub installed through self.stubs."""

    self.stubs.UnsetAll()
234
def Replay(*args):
  """Switch each given mock to replay mode, in argument order.

  Args:
    *args: any number of mocks; each must provide _Replay().
  """

  for each_mock in args:
    each_mock._Replay()
244
245
def Verify(*args):
  """Verify each given mock, in argument order.

  Args:
    *args: any number of mocks; each must provide _Verify().
  """

  for each_mock in args:
    each_mock._Verify()
255
256
def Reset(*args):
  """Reset each given mock, in argument order.

  Args:
    *args: any number of mocks; each must provide _Reset().
  """

  for each_mock in args:
    each_mock._Reset()
266
267
class MockAnything:
  """A mock that accepts any method name.

  Useful for mocking things that expose no interface to check against:
  every attribute lookup that is not otherwise resolved yields a MockMethod
  bound to this mock's expected-call queue.
  """

  def __init__(self):
    """Begin in record mode with an empty expected-call queue."""
    self._Reset()

  def __getattr__(self, method_name):
    """Turn an unresolved attribute lookup into a mock method.

    Args:
      method_name: str, the name being looked up.

    Returns:
      A new MockMethod carrying this mock's queue and current mode; calling
      it records or replays the call.
    """

    return self._CreateMockMethod(method_name)

  def _CreateMockMethod(self, method_name):
    """Build a MockMethod bound to this mock's queue and current mode.

    Args:
      method_name: str, the name of the method being mocked.

    Returns:
      The new MockMethod.
    """

    queue = self._expected_calls_queue
    mode = self._replay_mode
    return MockMethod(method_name, queue, mode)

  def __nonzero__(self):
    """Always truthy, so a mock can be used in a conditional."""

    return 1

  def __eq__(self, rhs):
    """Equal when rhs is a MockAnything in the same mode with the same queue."""

    if not isinstance(rhs, MockAnything):
      return False
    if not (self._replay_mode == rhs._replay_mode):
      return False
    return self._expected_calls_queue == rhs._expected_calls_queue

  def __ne__(self, rhs):
    """Negation of __eq__."""

    return not (self == rhs)

  def _Replay(self):
    """Leave record mode and start replaying expected calls."""

    self._replay_mode = True

  def _Verify(self):
    """Check that no expected call is still outstanding.

    Raises:
      ExpectedMethodCallsError: if the expected-call queue is not empty,
        unless its only entry is a satisfied MultipleTimesGroup.
    """

    pending = self._expected_calls_queue
    if not pending:
      return

    # A trailing MultipleTimesGroup is never popped from the queue, so a lone
    # satisfied one does not count as an unmet expectation.
    if len(pending) == 1:
      last = pending[0]
      if isinstance(last, MultipleTimesGroup) and last.IsSatisfied():
        return

    raise ExpectedMethodCallsError(pending)

  def _Reset(self):
    """Return to record mode and discard every recorded expectation."""

    # Calls recorded from now on are collected here.
    self._expected_calls_queue = deque()

    # Record (setup) mode, not replay mode.
    self._replay_mode = False
357
358
class MockObject(MockAnything, object):
  """A mock that only accepts attributes the mocked class actually has.

  Attribute names are collected with dir(class_to_mock) at construction:
  callable attributes can be mocked as methods, the others are served
  directly from the mocked class. Item assignment, subscripting and calling
  are mockable only when dir(class_to_mock) lists __setitem__, __getitem__
  or __call__ respectively, so inherited special methods count as well.
  """

  def __init__(self, class_to_mock):
    """Inspect class_to_mock and record its methods and variables.

    Args:
      class_to_mock: the class (or other object) whose interface is mocked.
    """

    # MockAnything is not a proper new-style object, so its initializer is
    # taken from the class __dict__ and called as a plain function.
    MockAnything.__dict__['__init__'](self)

    # Split the attribute names into callables (mockable methods) and
    # everything else (variables returned as-is).
    self._known_methods = set()
    self._known_vars = set()
    self._class_to_mock = class_to_mock
    for attr_name in dir(class_to_mock):
      if callable(getattr(class_to_mock, attr_name)):
        self._known_methods.add(attr_name)
      else:
        self._known_vars.add(attr_name)

  def __getattr__(self, name):
    """Resolve an attribute that normal lookup did not find.

    A known variable is read from the mocked class and is not recorded as a
    call. A known method yields a new MockMethod bound to this mock's queue
    and mode; the call is recorded or replayed when it is invoked.

    Args:
      name: str, the name of the attribute being requested.

    Returns:
      The mocked class's value for a known variable, or a new MockMethod for
      a known method.

    Raises:
      UnknownMethodCallError: if name is neither a known variable nor a
        known method of the mocked class.
    """

    if name in self._known_vars:
      return getattr(self._class_to_mock, name)

    if name in self._known_methods:
      return self._CreateMockMethod(name)

    raise UnknownMethodCallError(name)

  def __eq__(self, rhs):
    """Equal when rhs mocks the same class in the same mode and queue."""

    return (isinstance(rhs, MockObject) and
            self._class_to_mock == rhs._class_to_mock and
            self._replay_mode == rhs._replay_mode and
            self._expected_calls_queue == rhs._expected_calls_queue)

  def __setitem__(self, key, value):
    """Record or replay an item assignment on the mock.

    Args:
      key: Key to set the value for.
      value: Value to set.

    Returns:
      In replay mode, the return value configured for the expected call. In
      record mode, the MockMethod for __setitem__ that was just recorded.

    Raises:
      TypeError: if the mocked class does not support item assignment.
      UnexpectedMethodCallError: in replay mode, if the call to __setitem__
        was not expected.
    """

    # dir() also reports inherited methods and works for objects that have
    # no __dict__, unlike a lookup in the mocked class's own __dict__.
    if '__setitem__' not in dir(self._class_to_mock):
      raise TypeError('object does not support item assignment')

    # The MockMethod carries the current mode, so the same call records in
    # record mode and verifies in replay mode.
    return self._CreateMockMethod('__setitem__')(key, value)

  def __getitem__(self, key):
    """Record or replay a subscript access on the mock.

    Args:
      key: Key to return the value for.

    Returns:
      In replay mode, the return value configured for the expected call. In
      record mode, the MockMethod for __getitem__ that was just recorded.

    Raises:
      TypeError: if the mocked class is not subscriptable.
      UnexpectedMethodCallError: in replay mode, if the call to __getitem__
        was not expected.
    """

    # Same dir()-based detection as __setitem__.
    if '__getitem__' not in dir(self._class_to_mock):
      raise TypeError('unsubscriptable object')

    return self._CreateMockMethod('__getitem__')(key)

  def __call__(self, *params, **named_params):
    """Record or replay a direct call on the mock.

    Raises:
      TypeError: if the mocked class does not provide __call__.
    """

    # dir() of a class does not list metaclass attributes, so a plain class
    # is not treated as callable merely because it can be instantiated.
    if '__call__' not in dir(self._class_to_mock):
      raise TypeError('Not callable')

    # The call happens on the mock itself rather than on a named method, so
    # the mock method is created and invoked right here.
    mock_method = self._CreateMockMethod('__call__')
    return mock_method(*params, **named_params)

  @property
  def __class__(self):
    """Return the class that is being mocked."""

    return self._class_to_mock
508
509
class MockMethod(object):
  """A callable stand-in for one method of a mock.

  In record mode a call stores its arguments and queues the MockMethod as an
  expectation. In replay mode a call is checked against the next expectation
  in the queue and then runs that expectation's side effects, raises its
  exception, or returns its return value.
  """

  def __init__(self, method_name, call_queue, replay_mode):
    """Create a mock method bound to a call queue.

    Args:
      method_name: str, the name of the method.
      call_queue: list or deque of expected calls. This call is appended to
        it when recording and checked against its head when replaying. A
        non-deque is copied into a new deque.
      replay_mode: bool, False while recording, True while verifying calls
        against the queue.
    """

    if isinstance(call_queue, deque):
      queue = call_queue
    else:
      queue = deque(call_queue)

    self._name = method_name
    self._call_queue = queue
    self._replay_mode = replay_mode

    # Filled in by __call__ and by the AndReturn/AndRaise/WithSideEffects
    # setters.
    self._params = None
    self._named_params = None
    self._return_value = None
    self._exception = None
    self._side_effects = None

  def __call__(self, *params, **named_params):
    """Record this call, or replay it against the expected-call queue.

    Returns:
      In record mode, self (after appending self to the queue). In replay
      mode, the return value configured on the matching expected method.

    Raises:
      UnexpectedMethodCallError: in replay mode, if this call does not match
        the next expected call.
    """

    self._params = params
    self._named_params = named_params

    if self._replay_mode:
      matched = self._VerifyMethodCall()

      if matched._side_effects:
        matched._side_effects(*params, **named_params)

      if matched._exception:
        raise matched._exception

      return matched._return_value

    # Record mode: this call becomes the newest expectation.
    self._call_queue.append(self)
    return self

  def __getattr__(self, name):
    """Reject unknown attributes with a hint about replay mode."""

    raise AttributeError('MockMethod has no attribute "%s". '
        'Did you remember to put your mocks in replay mode?' % name)

  def _PopNextMethod(self):
    """Remove and return the head of the call queue.

    Raises:
      UnexpectedMethodCallError: if the queue is empty.
    """
    if not self._call_queue:
      raise UnexpectedMethodCallError(self, None)
    return self._call_queue.popleft()

  def _VerifyMethodCall(self):
    """Match this call against the next expectation.

    The expectation may be a single ordered method or a method group.

    Returns:
      The expected mock method that this call matched.

    Raises:
      UnexpectedMethodCallError: if this call was not expected.
    """

    candidate = self._PopNextMethod()

    # A group may hand over to whatever follows it, which can be another
    # group, so keep asking until a method is found or a plain method is left.
    while isinstance(candidate, MethodGroup):
      candidate, found = candidate.MethodCalled(self)
      if found is not None:
        return found

    # A plain mock method: it must equal this call.
    if candidate != self:
      raise UnexpectedMethodCallError(self, candidate)

    return candidate

  def __str__(self):
    """Render the call as name(args, key=value) -> return_value."""

    positional = [repr(value) for value in self._params or []]
    named_items = sorted((self._named_params or {}).items())
    keyword = ['%s=%r' % item for item in named_items]
    arguments = ', '.join(positional + keyword)
    return "%s(%s) -> %r" % (self._name, arguments, self._return_value)

  def __eq__(self, rhs):
    """Equal when rhs is a MockMethod with the same name and arguments.

    Args:
      rhs: the object to compare against.
    """

    if not isinstance(rhs, MockMethod):
      return False
    if not (self._name == rhs._name):
      return False
    if not (self._params == rhs._params):
      return False
    return self._named_params == rhs._named_params

  def __ne__(self, rhs):
    """Negation of __eq__.

    Args:
      rhs: the object to compare against.
    """

    return not (self == rhs)

  def GetPossibleGroup(self):
    """Take this method off the queue tail and return the new tail.

    Returns:
      The entry now at the end of the queue (possibly a group), or None if
      the queue is empty.
    """

    # This method must be the most recently recorded entry; remove it so it
    # can be placed into a group instead.
    tail = self._call_queue.pop()
    assert tail == self

    # Whatever is now last is either a group or a regular ordered method.
    if self._call_queue:
      return self._call_queue[-1]
    return None

  def _CheckAndCreateNewGroup(self, group_name, group_class):
    """Move this method into a group at the end of the queue.

    If the last queue entry is a group_class instance with the same name,
    the method joins it; otherwise a new group is created and queued.

    Args:
      group_name: the name of the group.
      group_class: the class used to create the group if one is needed.

    Returns:
      self
    """
    tail = self.GetPossibleGroup()
    reuse_tail = (isinstance(tail, group_class) and
                  tail.group_name() == group_name)

    if reuse_tail:
      target = tail
    else:
      target = group_class(group_name)

    target.AddMethod(self)
    if not reuse_tail:
      self._call_queue.append(target)
    return self

  def InAnyOrder(self, group_name="default"):
    """Move this method into a group of unordered calls.

    The calls of an unordered group must be declared together and must all
    be made before the next expected method can be called. Several groups
    can be expected one after another if they have different names. A name
    can be reused once a plain method call, or a group with another name,
    has been declared in between.

    Args:
      group_name: the name of the unordered group.

    Returns:
      self
    """
    return self._CheckAndCreateNewGroup(group_name, UnorderedGroup)

  def MultipleTimes(self, group_name="default"):
    """Move this method into a group of calls that may repeat.

    The calls of a repeating group must be declared together and must be
    executed in full before the next expected method can be called.

    Args:
      group_name: the name of the repeating group.

    Returns:
      self
    """
    return self._CheckAndCreateNewGroup(group_name, MultipleTimesGroup)

  def AndReturn(self, return_value):
    """Set the value returned when this method is replayed.

    Args:
      return_value: any object.

    Returns:
      return_value itself (not self).
    """

    self._return_value = return_value
    return return_value

  def AndRaise(self, exception):
    """Set the exception raised when this method is replayed.

    Args:
      exception: the exception to raise.
    """

    self._exception = exception

  def WithSideEffects(self, side_effects):
    """Set a callable to run when this method is replayed.

    Args:
      side_effects: a callable invoked with the arguments of the replayed
        call, e.g. to modify them or other state the test depends on.

    Returns:
      self, so AndReturn or AndRaise can be chained.
    """
    self._side_effects = side_effects
    return self
750
class Comparator:
  """Base class for all Mox comparators.

  A Comparator can stand in for a parameter of a mocked method when the
  exact value is not known. For example, the code under test might build a
  long SQL string that is passed to a mock DAO. If only the IN clause
  matters, the expectation can be written as:

    mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)

  Now any query passed in must contain the string 'IN (1, 2, 4, 5)'.

  A Comparator may replace one or more parameters, for example:
    # return at most 10 rows
    mock_dao.RunQuery(StrContains('SELECT'), 10)

  or

    # Return some non-deterministic number of rows
    mock_dao.RunQuery(StrContains('SELECT'), IsA(int))

  == and != are both routed through equals(), which subclasses implement.
  """

  def equals(self, rhs):
    """Decide whether rhs satisfies this comparator.

    Args:
      rhs: any python object

    Raises:
      NotImplementedError: always; subclasses must override this method.
    """

    # Call form instead of "raise Class, message": same exception, and valid
    # on every Python version rather than Python 2 only.
    raise NotImplementedError('method must be implemented by a subclass.')

  def __eq__(self, rhs):
    """Return the result of equals(rhs)."""
    return self.equals(rhs)

  def __ne__(self, rhs):
    """Return the negated result of equals(rhs)."""
    return not self.equals(rhs)
788
789
class IsA(Comparator):
  """Matches any parameter that is an instance of a given type or class.

  Example:
    mock_dao.Connect(IsA(DbConnectInfo))
  """

  def __init__(self, class_name):
    """Store the type to test against.

    Args:
      class_name: basic python type or a class
    """

    self._class_name = class_name

  def equals(self, rhs):
    """Report whether rhs is an instance of the stored type.

    Args:
      rhs: the object being tested.

    Returns:
      bool
    """

    try:
      matched = isinstance(rhs, self._class_name)
    except TypeError:
      # isinstance rejected the stored value as a type (this happens for
      # things like cStringIO.StringIO), so compare the raw types instead.
      matched = type(rhs) == type(self._class_name)
    return matched

  def __repr__(self):
    """Return the string form of the stored type."""
    return str(self._class_name)
827
class IsAlmost(Comparator):
  """Matches a parameter that is nearly equal to a given value.

  Generally useful for floating point numbers.

  Example mock_dao.SetTimeout((IsAlmost(3.9)))
  """

  def __init__(self, float_value, places=7):
    """Store the reference value and the rounding precision.

    Args:
      float_value: The value to compare against.
      places: The number of decimal places the difference is rounded to.
    """

    self._places = places
    self._float_value = float_value

  def equals(self, rhs):
    """Report whether rhs differs from float_value by zero after rounding.

    Args:
      rhs: the value to compare to float_value

    Returns:
      bool
    """

    try:
      difference = rhs - self._float_value
      return round(difference, self._places) == 0
    except TypeError:
      # Subtraction or rounding failed, most likely because one side is not
      # a number.
      return False

  def __repr__(self):
    """Return the string form of the reference value."""
    return str(self._float_value)
864
class StrContains(Comparator):
  """Matches a string parameter that contains a given substring.

  Handy when mocking a database that receives SQL as a string parameter.

  Example:
    mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)
  """

  def __init__(self, search_string):
    """Store the substring to look for.

    Args:
      search_string: str, the text that must occur in the parameter.
    """

    self._search_string = search_string

  def equals(self, rhs):
    """Report whether search_string occurs in rhs.

    Args:
      rhs: the object being tested.

    Returns:
      bool; False if rhs.find() raises.
    """

    try:
      position = rhs.find(self._search_string)
      return position > -1
    except Exception:
      return False

  def __repr__(self):
    """Describe the comparator, including the substring."""
    return "<str containing '%s'>" % self._search_string
902
903
class Regex(Comparator):
  """Matches a string parameter against a regular expression.

  The parameter matches when the pattern is found anywhere in it.
  """

  def __init__(self, pattern, flags=0):
    """Compile the pattern.

    Args:
      pattern: str, the regular expression to search for.
      flags: int, passed to re.compile as the flags argument.
    """

    self.regex = re.compile(pattern, flags=flags)

  def equals(self, rhs):
    """Report whether the pattern is found in rhs.

    Args:
      rhs: the object being tested.

    Returns:
      bool; False when rhs is not something the pattern can search.
    """

    try:
      return self.regex.search(rhs) is not None
    except TypeError:
      # A non-string argument cannot match. Letting the TypeError escape
      # would surface from parameter comparison instead of a mismatch.
      return False

  def __repr__(self):
    """Describe the comparator, including its pattern and any flags."""
    s = '<regular expression \'%s\'' % self.regex.pattern
    if self.regex.flags:
      s += ', flags=%d' % self.regex.flags
    s += '>'
    return s
937
938
class In(Comparator):
  """Matches a list (or dict) parameter that contains a given item (or key).

  Example:
    mock_dao.GetUsersInfo(In('expectedUserName')).AndReturn(mock_result)
  """

  def __init__(self, key):
    """Store the item to look for.

    Args:
      key: anything that could be an element of a list or a key of a dict.
    """

    self._key = key

  def equals(self, rhs):
    """Report whether key is contained in rhs.

    Args:
      rhs: the container being tested (list, dict, ...).

    Returns:
      bool; False when the membership test itself is not possible.
    """

    try:
      return self._key in rhs
    except TypeError:
      # rhs does not support "in" for this key (e.g. it is not a container),
      # so it cannot contain the key. Report a mismatch instead of raising.
      return False

  def __repr__(self):
    """Describe the comparator, including the key."""
    return '<sequence or map containing \'%s\'>' % self._key
969
970
class ContainsKeyValue(Comparator):
  """Matches a dict parameter that holds a given key/value pair.

  Example:
    mock_dao.UpdateUsers(ContainsKeyValue('stevepm', stevepm_user_info))
  """

  def __init__(self, key, value):
    """Store the pair to look for.

    Args:
      key: a key in a dict
      value: the value expected under that key
    """

    self._value = value
    self._key = key

  def equals(self, rhs):
    """Report whether rhs[key] equals value.

    Returns:
      bool; False if the lookup or the comparison raises.
    """

    try:
      stored = rhs[self._key]
      return stored == self._value
    except Exception:
      return False

  def __repr__(self):
    """Describe the comparator, including the key and value."""
    entry = (self._key, self._value)
    return "<map containing the entry '%s: %s'>" % entry
1003
1004
class SameElementsAs(Comparator):
  """Checks whether iterables contain the same elements (ignoring order).

  Example:
    mock_dao.ProcessUsers(SameElementsAs(['stevepm', 'salomaki']))
  """

  def __init__(self, expected_seq):
    """Store the expected elements.

    Args:
      expected_seq: a sequence
    """

    self._expected_seq = expected_seq

  def equals(self, actual_seq):
    """Check to see whether actual_seq has same elements as expected_seq.

    When every element is hashable the two are compared as sets. Otherwise
    both are compared as sorted lists.

    Args:
      actual_seq: sequence

    Returns:
      bool; False when actual_seq is not iterable.
    """

    try:
      # Materialize once: actual_seq may be a one-shot iterator, and the
      # unhashable fallback below needs a second pass over its elements.
      actual_items = tuple(actual_seq)
    except TypeError:
      # Not iterable, so it cannot hold the expected elements.
      return False

    # Same reasoning for the expected side within this call.
    expected_items = tuple(self._expected_seq)

    try:
      expected = set(expected_items)
      actual = set(actual_items)
    except TypeError:
      # Fall back to slower list-compare if any of the objects are unhashable.
      expected = list(expected_items)
      actual = list(actual_items)
      expected.sort()
      actual.sort()
    return expected == actual

  def __repr__(self):
    """Describe the comparator, including the expected sequence."""
    return '<sequence with same elements as \'%s\'>' % self._expected_seq
1044
1045
class And(Comparator):
  """Combines comparators; matches only when every one of them matches."""

  def __init__(self, *args):
    """Store the comparators to combine.

    Args:
      *args: One or more Comparator
    """

    self._comparators = args

  def equals(self, rhs):
    """Report whether all stored comparators accept rhs.

    Evaluation stops at the first comparator that rejects rhs.

    Args:
      rhs: can be anything

    Returns:
      bool
    """

    for candidate in self._comparators:
      if not candidate.equals(rhs):
        break
    else:
      # The loop finished without any comparator rejecting rhs.
      return True

    return False

  def __repr__(self):
    """Describe the comparator, including the combined comparators."""
    return '<AND %s>' % str(self._comparators)
1077
1078
class Or(Comparator):
  """Combines comparators; matches when at least one of them matches."""

  def __init__(self, *args):
    """Store the comparators to combine.

    Args:
      *args: One or more Mox comparators
    """

    self._comparators = args

  def equals(self, rhs):
    """Report whether any stored comparator accepts rhs.

    Evaluation stops at the first comparator that accepts rhs.

    Args:
      rhs: can be anything

    Returns:
      bool
    """

    for candidate in self._comparators:
      if candidate.equals(rhs):
        break
    else:
      # The loop finished without any comparator accepting rhs.
      return False

    return True

  def __repr__(self):
    """Describe the comparator, including the combined comparators."""
    return '<OR %s>' % str(self._comparators)
1110
1111
class Func(Comparator):
  """Delegates parameter validation to a user-supplied callable.

  Use this when validating a parameter needs more logic than the other
  comparators offer. The callable receives the parameter and should return
  either True or False.


  Example:

  def myParamValidator(param):
    # Advanced logic here
    return True

  mock_dao.DoSomething(Func(myParamValidator), true)
  """

  def __init__(self, func):
    """Store the validating callable.

    Args:
      func: callable that takes one parameter and returns a bool
    """

    self._func = func

  def equals(self, rhs):
    """Run the stored callable on rhs.

    Args:
      rhs: any python object

    Returns:
      whatever func(rhs) returns
    """

    verdict = self._func(rhs)
    return verdict

  def __repr__(self):
    """Return the string form of the stored callable."""
    return str(self._func)
1154
1155
class IgnoreArg(Comparator):
    """Comparator that accepts any argument.

    Use it for a method-call argument whose value does not matter.

    Example:
      # Check if CastMagic is called with 3 as first arg and 'disappear' as
      # third, whatever the second argument is.
      mymock.CastMagic(3, IgnoreArg(), 'disappear')
    """

    def equals(self, unused_rhs):
        """Accept the argument without looking at it.

        Args:
          unused_rhs: any python object; it is never inspected.

        Returns:
          True, unconditionally.
        """

        return True

    def __repr__(self):
        """Return the fixed marker '<IgnoreArg>'."""
        return '<IgnoreArg>'
1180
1181
class MethodGroup(object):
    """Common base for groups of expected mock method calls.

    Stores the group's name and declares the operations a concrete group has
    to provide: AddMethod, MethodCalled and IsSatisfied.
    """

    def __init__(self, group_name):
        """Remember the name of this group.

        Args:
          group_name: the name identifying this group.
        """
        self._group_name = group_name

    def group_name(self):
        """Return the name given at construction time."""
        name = self._group_name
        return name

    def __str__(self):
        """Return '<ClassName "group name">' for this group."""
        kind = self.__class__.__name__
        return '<%s "%s">' % (kind, self._group_name)

    def AddMethod(self, mock_method):
        """Not implemented in the base class; raises NotImplementedError."""
        raise NotImplementedError()

    def MethodCalled(self, mock_method):
        """Not implemented in the base class; raises NotImplementedError."""
        raise NotImplementedError()

    def IsSatisfied(self):
        """Not implemented in the base class; raises NotImplementedError."""
        raise NotImplementedError()
1202
class UnorderedGroup(MethodGroup):
    """A group of expected method calls that may arrive in any order.

    Handy for non-deterministic sequences, such as iterating over the keys
    of a dict.
    """

    def __init__(self, group_name):
        """Create an empty group.

        Args:
          group_name: the name identifying this group.
        """
        super(UnorderedGroup, self).__init__(group_name)
        self._methods = []

    def AddMethod(self, mock_method):
        """Append an expected method to this group.

        Args:
          mock_method: A mock method to be added to this group.
        """

        self._methods.append(mock_method)

    def MethodCalled(self, mock_method):
        """Consume one expected call that matches mock_method.

        Args:
          mock_method: a mock method that should be equal to a method in the
            group.

        Returns:
          A (group, method) tuple: this group and the expected method from
          the group that compared equal to mock_method.

        Raises:
          UnexpectedMethodCallError if the mock_method was not in the group.
        """

        for expected in self._methods:
            if not (expected == mock_method):
                continue

            # Remove by the *called* mock_method rather than by the expected
            # entry: the called method matches any comparators during the
            # equality check done by remove(), whereas the expected entry
            # could end up handing a comparator to another comparator.
            self._methods.remove(mock_method)

            # Expectations remain, so this group goes back to the head of the
            # call queue.
            if not self.IsSatisfied():
                mock_method._call_queue.appendleft(self)

            return self, expected

        raise UnexpectedMethodCallError(mock_method, self)

    def IsSatisfied(self):
        """Return True once no expected methods are left in this group."""

        return not self._methods
1261
1262
class MultipleTimesGroup(MethodGroup):
    """A group of expected methods that may each be called repeatedly.

    Note: Each method must be called at least once.

    Useful when the number of calls to a method is unknown or irrelevant.
    """

    def __init__(self, group_name):
        """Create a group with no expected and no recorded calls.

        Args:
          group_name: the name identifying this group.
        """
        super(MultipleTimesGroup, self).__init__(group_name)
        self._methods = set()
        self._methods_called = set()

    def AddMethod(self, mock_method):
        """Register an expected method with this group.

        Args:
          mock_method: A mock method to be added to this group.
        """

        self._methods.add(mock_method)

    def MethodCalled(self, mock_method):
        """Record a call, or move on to the next expectation.

        Args:
          mock_method: a mock method that should be equal to a method in the
            group.

        Returns:
          (group, method) -- this group and the matching expected method --
          when mock_method belongs to the group.  Otherwise, if the group is
          already satisfied, (next_method, None) where next_method is what
          mock_method._PopNextMethod() returned.

        Raises:
          UnexpectedMethodCallError if mock_method is not in the group and
          the group is not yet satisfied.
        """

        for expected in self._methods:
            if not (expected == mock_method):
                continue

            self._methods_called.add(mock_method)
            # There is no way to know when the calls are finished, so this
            # group always returns to the top of the queue.
            mock_method._call_queue.appendleft(self)
            return self, expected

        # Not one of ours: only acceptable once every method has been seen.
        if not self.IsSatisfied():
            raise UnexpectedMethodCallError(mock_method, self)

        next_method = mock_method._PopNextMethod()
        return next_method, None

    def IsSatisfied(self):
        """Return True if all methods in this group are called at least once."""
        # A plain set difference is not usable here: parameters that differ
        # as objects but are considered the same (e.g. IsA(str) and some
        # string) have to be matched through ==.  That makes this O(n^2),
        # but n should be small.
        outstanding = self._methods.copy()
        for seen in self._methods_called:
            for wanted in outstanding:
                if not (seen == wanted):
                    continue

                outstanding.remove(wanted)
                if not outstanding:
                    return True
                # This recorded call is accounted for; go to the next one.
                break

        return False
1332
1333
class MoxMetaTestBase(type):
    """Metaclass to add mox cleanup and verification to every test.

    As the mox unit testing class is being constructed (MoxTestBase or a
    subclass), this metaclass wraps all test functions so that the mox
    cleanup runs after they finish.  This means that unstubbing and verifying
    will happen for every test with no additional code, and any failures will
    result in test failures as opposed to errors.
    """

    def __init__(cls, name, bases, d):
        """Wrap every callable attribute whose name starts with 'test'.

        Args:
          cls: the class being constructed.
          name: the name of the class.
          bases: tuple of the base classes of cls.
          d: the namespace dict of the class body; it is left unmodified.
        """
        type.__init__(cls, name, bases, d)

        # Work on a copy so the caller's namespace dict is never mutated.
        candidates = dict(d)

        # Also collect the attributes of the base classes, to account for a
        # test class that is not the immediate child of MoxTestBase.  A name
        # already defined by the class itself (or by an earlier base) must
        # win; otherwise a test overridden in the subclass would be replaced
        # by the inherited implementation.
        for base in bases:
            for attr_name in dir(base):
                if attr_name not in candidates:
                    candidates[attr_name] = getattr(base, attr_name)

        for func_name, func in candidates.items():
            if func_name.startswith('test') and callable(func):
                setattr(cls, func_name, MoxMetaTestBase.CleanUpTest(cls, func))

    @staticmethod
    def CleanUpTest(cls, func):
        """Adds Mox cleanup code to any MoxTestBase method.

        Always unsets stubs after a test.  Will verify all mocks for tests
        that otherwise pass.

        Args:
          cls: MoxTestBase or subclass; the class whose test method we are
            altering.
          func: method; the method of the MoxTestBase test class we wish to
            alter.

        Returns:
          The modified method.
        """
        def new_method(self, *args, **kwargs):
            mox_obj = getattr(self, 'mox', None)
            # Only clean up when the instance really carries a Mox object.
            cleanup_mox = False
            if mox_obj and isinstance(mox_obj, Mox):
                cleanup_mox = True
            try:
                func(self, *args, **kwargs)
            finally:
                # Stubs are removed even when the test body raised.
                if cleanup_mox:
                    mox_obj.UnsetStubs()
            # Reached only when the test body did not raise.
            if cleanup_mox:
                mox_obj.VerifyAll()
        new_method.__name__ = func.__name__
        new_method.__doc__ = func.__doc__
        new_method.__module__ = func.__module__
        return new_method
1387
1388
class MoxTestBase(unittest.TestCase):
    """Convenience test class to make stubbing easier.

    Sets up a "mox" attribute which is an instance of Mox - any mox tests
    will want this.  The MoxMetaTestBase metaclass additionally unsets any
    stubs and verifies that all mock methods have been called at the end of
    each test, eliminating boilerplate code.
    """

    __metaclass__ = MoxMetaTestBase

    def setUp(self):
        """Run the inherited setUp, then create a fresh Mox as self.mox."""
        # Chain up so that a setUp defined by a cooperating base class or
        # mixin is not skipped.
        super(MoxTestBase, self).setUp()
        self.mox = Mox()