blob: 01e26c8dfaf259c37c70b61d6a8f506467cc9392 [file] [log] [blame]
Emily Morehouse8f59ee02019-01-24 16:49:56 -07001import os
2import unittest
3
# Module-level name that the global-scope tests in NamedExpressionScopeTest
# rebind via ``:=``; those tests expect it to be None outside their own run
# (test_named_expression_global_scope restores it in a ``finally``).
GLOBAL_VAR = None
Emily Morehouse8f59ee02019-01-24 16:49:56 -07005
6class NamedExpressionInvalidTest(unittest.TestCase):
7
8 def test_named_expression_invalid_01(self):
9 code = """x := 0"""
10
11 with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
12 exec(code, {}, {})
13
14 def test_named_expression_invalid_02(self):
15 code = """x = y := 0"""
16
17 with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
18 exec(code, {}, {})
19
20 def test_named_expression_invalid_03(self):
21 code = """y := f(x)"""
22
23 with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
24 exec(code, {}, {})
25
26 def test_named_expression_invalid_04(self):
27 code = """y0 = y1 := f(x)"""
28
29 with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
30 exec(code, {}, {})
31
32 def test_named_expression_invalid_06(self):
33 code = """((a, b) := (1, 2))"""
34
35 with self.assertRaisesRegex(SyntaxError, "cannot use named assignment with tuple"):
36 exec(code, {}, {})
37
38 def test_named_expression_invalid_07(self):
39 code = """def spam(a = b := 42): pass"""
40
41 with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
42 exec(code, {}, {})
43
44 def test_named_expression_invalid_08(self):
45 code = """def spam(a: b := 42 = 5): pass"""
46
47 with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
48 exec(code, {}, {})
49
50 def test_named_expression_invalid_09(self):
51 code = """spam(a=b := 'c')"""
52
53 with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
54 exec(code, {}, {})
55
56 def test_named_expression_invalid_10(self):
57 code = """spam(x = y := f(x))"""
58
59 with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
60 exec(code, {}, {})
61
62 def test_named_expression_invalid_11(self):
63 code = """spam(a=1, b := 2)"""
64
65 with self.assertRaisesRegex(SyntaxError,
66 "positional argument follows keyword argument"):
67 exec(code, {}, {})
68
69 def test_named_expression_invalid_12(self):
70 code = """spam(a=1, (b := 2))"""
71
72 with self.assertRaisesRegex(SyntaxError,
73 "positional argument follows keyword argument"):
74 exec(code, {}, {})
75
76 def test_named_expression_invalid_13(self):
77 code = """spam(a=1, (b := 2))"""
78
79 with self.assertRaisesRegex(SyntaxError,
80 "positional argument follows keyword argument"):
81 exec(code, {}, {})
82
83 def test_named_expression_invalid_14(self):
84 code = """(x := lambda: y := 1)"""
85
86 with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
87 exec(code, {}, {})
88
89 def test_named_expression_invalid_15(self):
90 code = """(lambda: x := 1)"""
91
92 with self.assertRaisesRegex(SyntaxError,
93 "cannot use named assignment with lambda"):
94 exec(code, {}, {})
95
96 def test_named_expression_invalid_16(self):
97 code = "[i + 1 for i in i := [1,2]]"
98
99 with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
100 exec(code, {}, {})
101
102 def test_named_expression_invalid_17(self):
103 code = "[i := 0, j := 1 for i, j in [(1, 2), (3, 4)]]"
104
105 with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
106 exec(code, {}, {})
107
Nick Coghlan6ca03072019-08-26 00:41:47 +1000108 def test_named_expression_invalid_in_class_body(self):
Emily Morehouse8f59ee02019-01-24 16:49:56 -0700109 code = """class Foo():
110 [(42, 1 + ((( j := i )))) for i in range(5)]
111 """
112
Nick Coghlan6ca03072019-08-26 00:41:47 +1000113 with self.assertRaisesRegex(SyntaxError,
114 "assignment expression within a comprehension cannot be used in a class body"):
Emily Morehouse8f59ee02019-01-24 16:49:56 -0700115 exec(code, {}, {})
116
Nick Coghlan6ca03072019-08-26 00:41:47 +1000117 def test_named_expression_invalid_rebinding_comprehension_iteration_variable(self):
118 cases = [
119 ("Local reuse", 'i', "[i := 0 for i in range(5)]"),
120 ("Nested reuse", 'j', "[[(j := 0) for i in range(5)] for j in range(5)]"),
121 ("Reuse inner loop target", 'j', "[(j := 0) for i in range(5) for j in range(5)]"),
122 ("Unpacking reuse", 'i', "[i := 0 for i, j in [(0, 1)]]"),
123 ("Reuse in loop condition", 'i', "[i+1 for i in range(5) if (i := 0)]"),
124 ("Unreachable reuse", 'i', "[False or (i:=0) for i in range(5)]"),
125 ("Unreachable nested reuse", 'i',
126 "[(i, j) for i in range(5) for j in range(5) if True or (i:=10)]"),
127 ]
128 for case, target, code in cases:
129 msg = f"assignment expression cannot rebind comprehension iteration variable '{target}'"
130 with self.subTest(case=case):
131 with self.assertRaisesRegex(SyntaxError, msg):
132 exec(code, {}, {})
133
134 def test_named_expression_invalid_rebinding_comprehension_inner_loop(self):
135 cases = [
136 ("Inner reuse", 'j', "[i for i in range(5) if (j := 0) for j in range(5)]"),
137 ("Inner unpacking reuse", 'j', "[i for i in range(5) if (j := 0) for j, k in [(0, 1)]]"),
138 ]
139 for case, target, code in cases:
140 msg = f"comprehension inner loop cannot rebind assignment expression target '{target}'"
141 with self.subTest(case=case):
142 with self.assertRaisesRegex(SyntaxError, msg):
143 exec(code, {}) # Module scope
144 with self.assertRaisesRegex(SyntaxError, msg):
145 exec(code, {}, {}) # Class scope
146 with self.assertRaisesRegex(SyntaxError, msg):
147 exec(f"lambda: {code}", {}) # Function scope
148
149 def test_named_expression_invalid_comprehension_iterable_expression(self):
150 cases = [
151 ("Top level", "[i for i in (i := range(5))]"),
152 ("Inside tuple", "[i for i in (2, 3, i := range(5))]"),
153 ("Inside list", "[i for i in [2, 3, i := range(5)]]"),
154 ("Different name", "[i for i in (j := range(5))]"),
155 ("Lambda expression", "[i for i in (lambda:(j := range(5)))()]"),
156 ("Inner loop", "[i for i in range(5) for j in (i := range(5))]"),
157 ("Nested comprehension", "[i for i in [j for j in (k := range(5))]]"),
158 ("Nested comprehension condition", "[i for i in [j for j in range(5) if (j := True)]]"),
159 ("Nested comprehension body", "[i for i in [(j := True) for j in range(5)]]"),
160 ]
161 msg = "assignment expression cannot be used in a comprehension iterable expression"
162 for case, code in cases:
163 with self.subTest(case=case):
164 with self.assertRaisesRegex(SyntaxError, msg):
165 exec(code, {}) # Module scope
166 with self.assertRaisesRegex(SyntaxError, msg):
167 exec(code, {}, {}) # Class scope
168 with self.assertRaisesRegex(SyntaxError, msg):
169 exec(f"lambda: {code}", {}) # Function scope
170
Emily Morehouse8f59ee02019-01-24 16:49:56 -0700171
class NamedExpressionAssignmentTest(unittest.TestCase):
    """Valid uses of ``:=`` and the values they bind."""

    def test_named_expression_assignment_01(self):
        (a := 10)

        self.assertEqual(a, 10)

    def test_named_expression_assignment_02(self):
        # Rebinding a name to its own current value.
        a = 20
        (a := a)

        self.assertEqual(a, 20)

    def test_named_expression_assignment_03(self):
        (total := 1 + 2)

        self.assertEqual(total, 3)

    def test_named_expression_assignment_04(self):
        (info := (1, 2, 3))

        self.assertEqual(info, (1, 2, 3))

    def test_named_expression_assignment_05(self):
        # ``:=`` binds tighter than the comma: x receives 1, not (1, 2).
        (x := 1, 2)

        self.assertEqual(x, 1)

    def test_named_expression_assignment_06(self):
        # Nested named expressions all receive the innermost value.
        (z := (y := (x := 0)))

        self.assertEqual(x, 0)
        self.assertEqual(y, 0)
        self.assertEqual(z, 0)

    def test_named_expression_assignment_07(self):
        (loc := (1, 2))

        self.assertEqual(loc, (1, 2))

    def test_named_expression_assignment_08(self):
        # Unparenthesized ``:=`` is allowed as an ``if`` condition.
        if spam := "eggs":
            self.assertEqual(spam, "eggs")
        else:
            self.fail("variable was not assigned using named expression")

    def test_named_expression_assignment_09(self):
        if True and (spam := True):
            self.assertTrue(spam)
        else:
            self.fail("variable was not assigned using named expression")

    def test_named_expression_assignment_10(self):
        if (match := 10) == 10:
            # Check the binding itself, not only the comparison result.
            self.assertEqual(match, 10)
        else:
            self.fail("variable was not assigned using named expression")

    def test_named_expression_assignment_11(self):
        # Target bound in the filter clause is usable in the element.
        def spam(a):
            return a
        input_data = [1, 2, 3]
        res = [(x, y, x/y) for x in input_data if (y := spam(x)) > 0]

        self.assertEqual(res, [(1, 1, 1.0), (2, 2, 1.0), (3, 3, 1.0)])

    def test_named_expression_assignment_12(self):
        # Target bound earlier in the element is usable later in it.
        def spam(a):
            return a
        res = [[y := spam(x), x/y] for x in range(1, 5)]

        self.assertEqual(res, [[1, 1.0], [2, 1.0], [3, 1.0], [4, 1.0]])

    def test_named_expression_assignment_13(self):
        length = len(lines := [1, 2])

        self.assertEqual(length, 2)
        self.assertEqual(lines, [1,2])

    def test_named_expression_assignment_14(self):
        """
        Where all variables are positive integers, and a is at least as large
        as the n'th root of x, this algorithm (integer Newton's method)
        returns the floor of the n'th root of x, roughly doubling the number
        of accurate bits per iteration.
        """
        a = 9
        n = 2
        x = 3

        while a > (d := x // a**(n-1)):
            a = ((n-1)*a + d) // n

        # floor(sqrt(3)) == 1
        self.assertEqual(a, 1)

    def test_named_expression_assignment_15(self):
        # The target is bound even when the loop condition is false at once.
        while a := False:
            pass # This will not run

        self.assertEqual(a, False)

    def test_named_expression_assignment_16(self):
        # The key is evaluated before the value, so c captures the current a.
        # The value then steps (a, b) to (b, a + b) and evaluates to the new
        # a, producing consecutive Fibonacci pairs.
        a, b = 1, 2
        fib = {(c := a): (a := b) + (b := a + c) - b for __ in range(6)}
        self.assertEqual(fib, {1: 2, 2: 3, 3: 5, 5: 8, 8: 13, 13: 21})
274
Emily Morehouse8f59ee02019-01-24 16:49:56 -0700275
class NamedExpressionScopeTest(unittest.TestCase):
    """Which scope receives the binding made by ``:=``.

    A named expression inside a comprehension or generator expression binds
    in the nearest enclosing scope that is not itself a comprehension, and
    honours any ``global``/``nonlocal`` declaration made there.
    """

    def test_named_expression_scope_01(self):
        # `a` is bound only in spam's own scope (and spam is never called),
        # so the module-level print(a) cannot see it.
        code = """def spam():
    (a := 5)
print(a)"""

        with self.assertRaisesRegex(NameError, "name 'a' is not defined"):
            exec(code, {}, {})

    def test_named_expression_scope_02(self):
        # The comprehension's target rebinds this test's local `total`, so it
        # accumulates across iterations and keeps its final value afterwards.
        total = 0
        partial_sums = [total := total + v for v in range(5)]

        self.assertEqual(partial_sums, [0, 1, 3, 6, 10])
        self.assertEqual(total, 10)

    def test_named_expression_scope_03(self):
        # Generator expressions leak the target too; any() stops at the first
        # match, so lastNum holds the element that satisfied the test.
        containsOne = any((lastNum := num) == 1 for num in [1, 2, 3])

        self.assertTrue(containsOne)
        self.assertEqual(lastNum, 1)

    def test_named_expression_scope_04(self):
        # `y` is visible after the comprehension with its last value; `res`
        # is bound only so the comprehension runs and is not checked here.
        def spam(a):
            return a
        res = [[y := spam(x), x/y] for x in range(1, 5)]

        self.assertEqual(y, 4)

    def test_named_expression_scope_05(self):
        # A target bound in the filter clause leaks as well.
        def spam(a):
            return a
        input_data = [1, 2, 3]
        res = [(x, y, x/y) for x in input_data if (y := spam(x)) > 0]

        self.assertEqual(res, [(1, 1, 1.0), (2, 2, 1.0), (3, 3, 1.0)])
        self.assertEqual(y, 3)

    def test_named_expression_scope_06(self):
        # With nested comprehensions the target skips both comprehension
        # scopes and binds in this test function.
        res = [[spam := i for i in range(3)] for j in range(2)]

        self.assertEqual(res, [[0, 1, 2], [0, 1, 2]])
        self.assertEqual(spam, 2)

    def test_named_expression_scope_07(self):
        # A target inside an ordinary call argument binds in the current scope.
        len(lines := [1, 2])

        self.assertEqual(lines, [1, 2])

    def test_named_expression_scope_08(self):
        # Targets nested in call arguments inside a comprehension all leak,
        # keeping their values from the final iteration (h == 1).
        def spam(a):
            return a

        def eggs(b):
            return b * 2

        res = [spam(a := eggs(b := h)) for h in range(2)]

        self.assertEqual(res, [0, 2])
        self.assertEqual(a, 2)
        self.assertEqual(b, 1)

    def test_named_expression_scope_09(self):
        # `a` is targeted twice; the outer `a := eggs(...)` is evaluated after
        # the inner `a := h`, so it determines the final value.
        def spam(a):
            return a

        def eggs(b):
            return b * 2

        res = [spam(a := eggs(a := h)) for h in range(2)]

        self.assertEqual(res, [0, 2])
        self.assertEqual(a, 2)

    def test_named_expression_scope_10(self):
        # Targets in both the outer and the nested comprehension bind here.
        res = [b := [a := 1 for i in range(2)] for j in range(2)]

        self.assertEqual(res, [[1, 1], [1, 1]])
        self.assertEqual(a, 1)
        self.assertEqual(b, [1, 1])

    def test_named_expression_scope_11(self):
        # Simplest leak: the target survives with the last value assigned.
        res = [j := i for i in range(5)]

        self.assertEqual(res, [0, 1, 2, 3, 4])
        self.assertEqual(j, 4)

    def test_named_expression_scope_17(self):
        # The target is also read in the element, seeing the value bound on
        # the previous iteration (a running sum seeded by the outer b = 0).
        b = 0
        res = [b := i + b for i in range(5)]

        self.assertEqual(res, [0, 1, 3, 6, 10])
        self.assertEqual(b, 10)

    def test_named_expression_scope_18(self):
        # Bare named expression as a positional call argument.
        def spam(a):
            return a

        res = spam(b := 2)

        self.assertEqual(res, 2)
        self.assertEqual(b, 2)

    def test_named_expression_scope_19(self):
        # Parenthesized named expression as a positional call argument.
        def spam(a):
            return a

        res = spam((b := 2))

        self.assertEqual(res, 2)
        self.assertEqual(b, 2)

    def test_named_expression_scope_20(self):
        # Parenthesized named expression as a keyword argument value.
        def spam(a):
            return a

        res = spam(a=(b := 2))

        self.assertEqual(res, 2)
        self.assertEqual(b, 2)

    def test_named_expression_scope_21(self):
        # Bare named expression as a positional argument before a keyword one.
        def spam(a, b):
            return a + b

        res = spam(c := 2, b=1)

        self.assertEqual(res, 3)
        self.assertEqual(c, 2)

    def test_named_expression_scope_22(self):
        # Parenthesized variant of test 21.
        def spam(a, b):
            return a + b

        res = spam((c := 2), b=1)

        self.assertEqual(res, 3)
        self.assertEqual(c, 2)

    def test_named_expression_scope_23(self):
        # Named expression as the value of a keyword argument given out of order.
        def spam(a, b):
            return a + b

        res = spam(b=(c := 2), a=1)

        self.assertEqual(res, 3)
        self.assertEqual(c, 2)

    def test_named_expression_scope_24(self):
        # `nonlocal a` in spam makes ``:=`` rebind this test's `a`.
        a = 10
        def spam():
            nonlocal a
            (a := 20)
        spam()

        self.assertEqual(a, 20)

    def test_named_expression_scope_25(self):
        # exec gets separate globals (ns) and locals mappings: the top-level
        # `a = 10` lands in the locals dict, so ns only gets 'a' through
        # spam's `global a` rebinding via ``:=``.
        ns = {}
        code = """a = 10
def spam():
    global a
    (a := 20)
spam()"""

        exec(code, ns, {})

        self.assertEqual(ns["a"], 20)

    def test_named_expression_variable_reuse_in_comprehensions(self):
        # The compiler is expected to raise SyntaxError for rebinding
        # comprehension iteration variables, but should be fine with
        # rebinding other names (e.g. globals, nonlocals, other assignment
        # expressions).

        # The cases are all defined to produce the same expected result.
        # Each comprehension is checked at both function scope and module scope.
        rebinding = "[x := i for i in range(3) if (x := i) or not x]"
        filter_ref = "[x := i for i in range(3) if x or not x]"
        body_ref = "[x for i in range(3) if (x := i) or not x]"
        nested_ref = "[j for i in range(3) if x or not x for j in range(3) if (x := i)][:-3]"
        cases = [
            ("Rebind global", f"x = 1; result = {rebinding}"),
            ("Rebind nonlocal", f"result, x = (lambda x=1: ({rebinding}, x))()"),
            ("Filter global", f"x = 1; result = {filter_ref}"),
            ("Filter nonlocal", f"result, x = (lambda x=1: ({filter_ref}, x))()"),
            ("Body global", f"x = 1; result = {body_ref}"),
            ("Body nonlocal", f"result, x = (lambda x=1: ({body_ref}, x))()"),
            ("Nested global", f"x = 1; result = {nested_ref}"),
            ("Nested nonlocal", f"result, x = (lambda x=1: ({nested_ref}, x))()"),
        ]
        for case, code in cases:
            with self.subTest(case=case):
                ns = {}
                exec(code, ns)
                self.assertEqual(ns["x"], 2)
                self.assertEqual(ns["result"], [0, 1, 2])

    def test_named_expression_global_scope(self):
        # With `global GLOBAL_VAR` in f, the comprehension target writes the
        # module global; the `global` statement in this method lets the
        # ``finally`` clause restore it to None afterwards.
        sentinel = object()
        global GLOBAL_VAR
        def f():
            global GLOBAL_VAR
            [GLOBAL_VAR := sentinel for _ in range(1)]
            self.assertEqual(GLOBAL_VAR, sentinel)
        try:
            f()
            self.assertEqual(GLOBAL_VAR, sentinel)
        finally:
            GLOBAL_VAR = None

    def test_named_expression_global_scope_no_global_keyword(self):
        # Without a `global` declaration GLOBAL_VAR is a local of f, so the
        # comprehension target rebinds that local and the module global is
        # left untouched.
        sentinel = object()
        def f():
            GLOBAL_VAR = None
            [GLOBAL_VAR := sentinel for _ in range(1)]
            self.assertEqual(GLOBAL_VAR, sentinel)
        f()
        self.assertEqual(GLOBAL_VAR, None)

    def test_named_expression_nonlocal_scope(self):
        # `nonlocal nonlocal_var` in g makes the comprehension target rebind
        # f's variable.
        sentinel = object()
        def f():
            nonlocal_var = None
            def g():
                nonlocal nonlocal_var
                [nonlocal_var := sentinel for _ in range(1)]
            g()
            self.assertEqual(nonlocal_var, sentinel)
        f()

    def test_named_expression_nonlocal_scope_no_nonlocal_keyword(self):
        # Without `nonlocal`, the target becomes a local of g and f's
        # variable keeps its original value.
        sentinel = object()
        def f():
            nonlocal_var = None
            def g():
                [nonlocal_var := sentinel for _ in range(1)]
            g()
            self.assertEqual(nonlocal_var, None)
        f()
516
517
if __name__ == "__main__":
    # Run every TestCase defined in this module when executed as a script.
    unittest.main(module="__main__")