# Copyright 2008 Google Inc, Martin J. Bligh <mbligh@google.com>,
# Benjamin Poirier, Ryan Stutsman
# Released under the GPL v2
"""
Miscellaneous small functions.

DO NOT import this file directly - it is mixed in by server/utils.py,
import that instead
"""

import atexit, os, re, shutil, textwrap, sys, tempfile, types

from autotest_lib.client.common_lib import barrier, utils
from autotest_lib.server import subcommand


# A dictionary of pid and a list of tmpdirs for that pid
__tmp_dirs = {}

def scp_remote_escape(filename):
    """
    Escape special characters from a filename so that it can be passed
    to scp (within double quotes) as a remote file.

    "Bis-quoting" (quoting twice) has to be used with scp for remote
    files. scp does not support a newline in the filename.

    Args:
        filename: the filename string to escape.

    Returns:
        The escaped filename string. The required enclosing double
        quotes are NOT added and so should be added at some point by
        the caller.
    """
    escape_chars = r' !"$&' "'" r'()*,:;<=>?[\]^`{|}'

    new_name = []
    for char in filename:
        if char in escape_chars:
            new_name.append("\\%s" % (char,))
        else:
            new_name.append(char)

    return utils.sh_escape("".join(new_name))

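# Illustrative usage sketch (not exercised in this module): combine
# scp_remote_escape() with utils.sh_escape() when building an scp command
# line.  The names "local_path", "remote_path" and "hostname" below are
# hypothetical:
#
#     src = utils.sh_escape(local_path)
#     dst = scp_remote_escape(remote_path)
#     utils.run('scp %s root@%s:"%s"' % (src, hostname, dst))
#
# Note the enclosing double quotes around the remote argument, which the
# caller must supply as described in the docstring above.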

def get(location, local_copy = False):
    """Get a file or directory to a local temporary directory.

    Args:
        location: the source of the material to get. This source may
            be one of:
            * a local file or directory
            * a URL (http or ftp)
            * a python file-like object

    Returns:
        The location of the file or directory where the requested
        content was saved. This will be contained in a temporary
        directory on the local host. If the material to get was a
        directory, the location will contain a trailing '/'
    """
    tmpdir = get_tmp_dir()

    # location is a file-like object
    if hasattr(location, "read"):
        tmpfile = os.path.join(tmpdir, "file")
        tmpfileobj = file(tmpfile, 'w')
        shutil.copyfileobj(location, tmpfileobj)
        tmpfileobj.close()
        return tmpfile

    if isinstance(location, types.StringTypes):
        # location is a URL
        if location.startswith('http') or location.startswith('ftp'):
            tmpfile = os.path.join(tmpdir, os.path.basename(location))
            utils.urlretrieve(location, tmpfile)
            return tmpfile
        # location is a local path
        elif os.path.exists(os.path.abspath(location)):
            if not local_copy:
                if os.path.isdir(location):
                    return location.rstrip('/') + '/'
                else:
                    return location
            tmpfile = os.path.join(tmpdir, os.path.basename(location))
            if os.path.isdir(location):
                tmpfile += '/'
                shutil.copytree(location, tmpfile, symlinks=True)
                return tmpfile
            shutil.copyfile(location, tmpfile)
            return tmpfile
        # location is just a string, dump it to a file
        else:
            tmpfd, tmpfile = tempfile.mkstemp(dir=tmpdir)
            tmpfileobj = os.fdopen(tmpfd, 'w')
            tmpfileobj.write(location)
            tmpfileobj.close()
            return tmpfile

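# A minimal sketch of the kinds of sources get() accepts (the paths and URL
# below are made-up examples):
#
#     path1 = get('/etc/hosts')                    # existing local path,
#                                                  # returned as-is
#     path2 = get('http://example.com/a.tar.gz')   # URL, downloaded to a
#                                                  # temporary directory
#     path3 = get(StringIO.StringIO('some data'))  # file-like object, dumped
#                                                  # to a temporary file
#
# Everything except an existing local path with local_copy=False ends up
# under a directory created by get_tmp_dir().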

def get_tmp_dir():
    """Return the pathname of a directory on the host suitable
    for temporary file storage.

    The directory and its content will be deleted automatically
    at the end of the program execution if they are still present.
    """
    dir_name = tempfile.mkdtemp(prefix="autoserv-")
    pid = os.getpid()
    if not pid in __tmp_dirs:
        __tmp_dirs[pid] = []
    __tmp_dirs[pid].append(dir_name)
    return dir_name


def __clean_tmp_dirs():
    """Erase temporary directories that were created by the get_tmp_dir()
    function and that are still present.
    """
    pid = os.getpid()
    if pid not in __tmp_dirs:
        return
    for dir in __tmp_dirs[pid]:
        try:
            shutil.rmtree(dir)
        except OSError, e:
            if e.errno == 2:
                pass
    __tmp_dirs[pid] = []
atexit.register(__clean_tmp_dirs)
subcommand.subcommand.register_join_hook(lambda _: __clean_tmp_dirs())

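# Lifecycle sketch for the temporary directories managed above (illustrative
# only): each directory returned by get_tmp_dir() is recorded per-pid in
# __tmp_dirs and removed by __clean_tmp_dirs() at interpreter exit or when a
# subcommand is joined, so callers never clean up explicitly:
#
#     workdir = get_tmp_dir()
#     open(os.path.join(workdir, 'scratch.txt'), 'w').write('data')
#     # ... use workdir; it is deleted automatically on exit ...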

def unarchive(host, source_material):
    """Uncompress and untar an archive on a host.

    If the "source_material" is compressed (according to the file
    extension) it will be uncompressed. Supported compression formats
    are gzip and bzip2. Afterwards, if the source_material is a tar
    archive, it will be untarred.

    Args:
        host: the host object on which the archive is located
        source_material: the path of the archive on the host

    Returns:
        The file or directory name of the unarchived source material.
        If the material is a tar archive, it will be extracted in the
        directory where it is and the path returned will be the first
        entry in the archive, assuming it is the topmost directory.
        If the material is not an archive, nothing will be done so this
        function is "harmless" when it is "useless".
    """
    # uncompress
    if (source_material.endswith(".gz") or
            source_material.endswith(".gzip")):
        host.run('gunzip "%s"' % (utils.sh_escape(source_material)))
        source_material = ".".join(source_material.split(".")[:-1])
    elif source_material.endswith("bz2"):
        host.run('bunzip2 "%s"' % (utils.sh_escape(source_material)))
        source_material = ".".join(source_material.split(".")[:-1])

    # untar
    if source_material.endswith(".tar"):
        retval = host.run('tar -C "%s" -xvf "%s"' % (
                utils.sh_escape(os.path.dirname(source_material)),
                utils.sh_escape(source_material),))
        source_material = os.path.join(os.path.dirname(source_material),
                                       retval.stdout.split()[0])

    return source_material

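# Usage sketch for unarchive(); "host" is assumed to be a server-side Host
# object (anything exposing a run() method) and the path is made up:
#
#     top = unarchive(host, '/tmp/linux-2.6.tar.gz')
#     # gunzips to /tmp/linux-2.6.tar, untars it in /tmp and returns the
#     # path of the first entry of the archive, e.g. '/tmp/linux-2.6'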

def get_server_dir():
    path = os.path.dirname(sys.modules['autotest_lib.server.utils'].__file__)
    return os.path.abspath(path)


def find_pid(command):
    for line in utils.system_output('ps -eo pid,cmd').rstrip().split('\n'):
        (pid, cmd) = line.split(None, 1)
        if re.search(command, cmd):
            return int(pid)
    return None


def nohup(command, stdout='/dev/null', stderr='/dev/null', background=True,
          env = {}):
    cmd = ' '.join(key+'='+val for key, val in env.iteritems())
    cmd += ' nohup ' + command
    cmd += ' > %s' % stdout
    if stdout == stderr:
        cmd += ' 2>&1'
    else:
        cmd += ' 2> %s' % stderr
    if background:
        cmd += ' &'
    utils.system(cmd)

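# Sketch of the command line nohup() assembles before handing it to
# utils.system() (the command, log path and environment are illustrative):
#
#     nohup('sleep 60', stdout='/tmp/out.log', stderr='/tmp/out.log',
#           env={'FOO': 'bar'})
#     # runs: FOO=bar nohup sleep 60 > /tmp/out.log 2>&1 &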

def default_mappings(machines):
    """
    Returns a simple mapping in which all machines are assigned to the
    same key. Provides the default behavior for
    form_ntuples_from_machines. """
    mappings = {}
    failures = []

    mach = machines[0]
    mappings['ident'] = [mach]
    if len(machines) > 1:
        machines = machines[1:]
        for machine in machines:
            mappings['ident'].append(machine)

    return (mappings, failures)


def form_ntuples_from_machines(machines, n=2, mapping_func=default_mappings):
    """Returns a set of ntuples from machines where the machines in an
    ntuple are in the same mapping, and a set of failures which are
    (machine name, reason) tuples."""
    ntuples = []
    (mappings, failures) = mapping_func(machines)

    # now run through the mappings and create n-tuples,
    # throwing the odd ones out
    for key in mappings:
        key_machines = mappings[key]
        total_machines = len(key_machines)

        # form n-tuples
        while len(key_machines) >= n:
            ntuples.append(key_machines[0:n])
            key_machines = key_machines[n:]

        for mach in key_machines:
            failures.append((mach, "machine can not be tupled"))

    return (ntuples, failures)

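# Worked example for form_ntuples_from_machines() with the default mapping,
# which puts every machine under the single key 'ident' (hostnames are
# made up):
#
#     machines = ['host1', 'host2', 'host3', 'host4', 'host5']
#     ntuples, failures = form_ntuples_from_machines(machines, n=2)
#     # ntuples  -> [['host1', 'host2'], ['host3', 'host4']]
#     # failures -> [('host5', 'machine can not be tupled')]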

def parse_machine(machine, user = 'root', port = 22, password = ''):
    """
    Parse the machine string user:pass@host:port and return its components
    separately. If the machine string is not complete, fall back to the
    default parameters where appropriate.
    """

    user = user
    port = port
    password = password

    if re.search('@', machine):
        machine = machine.split('@')

        if re.search(':', machine[0]):
            machine[0] = machine[0].split(':')
            user = machine[0][0]
            password = machine[0][1]

        else:
            user = machine[0]

        if re.search(':', machine[1]):
            machine[1] = machine[1].split(':')
            hostname = machine[1][0]
            port = int(machine[1][1])

        else:
            hostname = machine[1]

    elif re.search(':', machine):
        machine = machine.split(':')
        hostname = machine[0]
        port = int(machine[1])

    else:
        hostname = machine

    return hostname, user, password, port

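# Worked examples for parse_machine(); hostnames and credentials are
# illustrative only.  Note the return order is (hostname, user, password,
# port):
#
#     parse_machine('mercury.example.com')
#     # -> ('mercury.example.com', 'root', '', 22)
#     parse_machine('user:secret@mercury.example.com:2222')
#     # -> ('mercury.example.com', 'user', 'secret', 2222)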

def get_public_key():
    """
    Return a valid string ssh public key for the user executing autoserv or
    autotest. If there's no DSA or RSA public key, create a DSA keypair with
    ssh-keygen and return it.
    """

    ssh_conf_path = os.path.expanduser('~/.ssh')

    dsa_public_key_path = os.path.join(ssh_conf_path, 'id_dsa.pub')
    dsa_private_key_path = os.path.join(ssh_conf_path, 'id_dsa')

    rsa_public_key_path = os.path.join(ssh_conf_path, 'id_rsa.pub')
    rsa_private_key_path = os.path.join(ssh_conf_path, 'id_rsa')

    has_dsa_keypair = os.path.isfile(dsa_public_key_path) and \
        os.path.isfile(dsa_private_key_path)
    has_rsa_keypair = os.path.isfile(rsa_public_key_path) and \
        os.path.isfile(rsa_private_key_path)

    if has_dsa_keypair:
        print 'DSA keypair found, using it'
        public_key_path = dsa_public_key_path

    elif has_rsa_keypair:
        print 'RSA keypair found, using it'
        public_key_path = rsa_public_key_path

    else:
        print 'Neither RSA nor DSA keypair found, creating DSA ssh key pair'
        utils.system('ssh-keygen -t dsa -q -N "" -f %s' % dsa_private_key_path)
        public_key_path = dsa_public_key_path

    public_key = open(public_key_path, 'r')
    public_key_str = public_key.read()
    public_key.close()

    return public_key_str

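# Typical use of get_public_key() (a sketch only; pushing the key to a remote
# machine is not something this module does by itself, and "host" below is an
# assumed Host object):
#
#     key = get_public_key()
#     host.run("mkdir -p ~/.ssh")
#     host.run("echo '%s' >> ~/.ssh/authorized_keys" % key.strip())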

def get_sync_control_file(control, host_name, host_num,
                          instance, num_jobs, port_base=63100):
    """
    This function is used when there is a need to run more than one
    job simultaneously starting exactly at the same time. It returns
    a modified control file (with the synchronization code prepended)
    once it is ready to be run. The synchronization is done using
    barriers to make sure that the jobs start at the same time.

    Here is how the synchronization is done to make sure that the tests
    start at exactly the same time on the client.
    sc_bar is a server barrier and s_bar, c_bar are the normal barriers

                    Job1           Job2           ......       JobN
     Server:   |                   sc_bar
     Server:   |                   s_bar          ......       s_bar
     Server:   |   at.run()        at.run()       ......       at.run()
     ----------|---------------------------------------------------------
     Client    |   sc_bar
     Client    |   c_bar           c_bar          ......       c_bar
     Client    |   <run test>      <run test>     ......       <run test>

    @param control: The control file to which the above synchronization
            code will be prepended.
    @param host_name: The host name on which the job is going to run.
    @param host_num: (non-negative) A number to identify the machine so that
            we have different sets of s_bar_ports for each of the machines.
    @param instance: The number of the job
    @param num_jobs: Total number of jobs that are going to run in parallel
            with this job starting at the same time.
    @param port_base: Port number that is used to derive the actual barrier
            ports.

    @returns The modified control file.
    """
    sc_bar_port = port_base
    c_bar_port = port_base
    if host_num < 0:
        print "Please provide a non negative number for the host"
        return None
    s_bar_port = port_base + 1 + host_num # The set of s_bar_ports are
                                          # the same for a given machine

    sc_bar_timeout = 180
    s_bar_timeout = c_bar_timeout = 120

    # The barrier code snippet is prepended into the control file
    # dynamically before at.run() is finally called.
    control_new = []

    # jobid is the unique name used to identify the processes
    # trying to reach the barriers
    jobid = "%s#%d" % (host_name, instance)

    rendv = []
    # rendvstr is a temp holder for the rendezvous list of the processes
    for n in range(num_jobs):
        rendv.append("'%s#%d'" % (host_name, n))
    rendvstr = ",".join(rendv)

    if instance == 0:
        # Do the setup and wait at the server barrier
        # Clean up the tmp and the control dirs for the first instance
        control_new.append('if os.path.exists(job.tmpdir):')
        control_new.append("\t system('umount -f %s > /dev/null "
                           "2> /dev/null' % job.tmpdir,"
                           "ignore_status=True)")
        control_new.append("\t system('rm -rf ' + job.tmpdir)")
        control_new.append(
                'b0 = job.barrier("%s", "sc_bar", %d, port=%d)'
                % (jobid, sc_bar_timeout, sc_bar_port))
        control_new.append(
                'b0.rendezvous_servers("PARALLEL_MASTER", "%s")'
                % jobid)

    elif instance == 1:
        # Wait at the server barrier to wait for instance=0
        # process to complete setup
        b0 = barrier.barrier("PARALLEL_MASTER", "sc_bar", sc_bar_timeout,
                             port=sc_bar_port)
        b0.rendezvous_servers("PARALLEL_MASTER", jobid)

        if(num_jobs > 2):
            b1 = barrier.barrier(jobid, "s_bar", s_bar_timeout,
                                 port=s_bar_port)
            b1.rendezvous(rendvstr)

    else:
        # For the rest of the clients
        b2 = barrier.barrier(jobid, "s_bar", s_bar_timeout, port=s_bar_port)
        b2.rendezvous(rendvstr)

    # Client side barrier for all the tests to start at the same time
    control_new.append('b1 = job.barrier("%s", "c_bar", %d, port=%d)'
                       % (jobid, c_bar_timeout, c_bar_port))
    control_new.append("b1.rendezvous(%s)" % rendvstr)

    # Stick in the rest of the control file
    control_new.append(control)

    return "\n".join(control_new)
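

# Sketch of how get_sync_control_file() is meant to be driven (illustrative;
# "control_text" and the job-launching step are hypothetical).  Each call is
# expected to run in its own parallel server process (e.g. via subcommand),
# because instances other than 0 block on the server barrier until the
# instance-0 control file has started running:
#
#     n_jobs = 3
#     # inside the per-job server process for job number i:
#     synced = get_sync_control_file(control_text, 'client.example.com',
#                                    host_num=0, instance=i,
#                                    num_jobs=n_jobs)
#     # hand "synced" to at.run(); the prepended barrier code makes all
#     # n_jobs tests start on the client at the same time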