blob: 25f98ad318dedc393357742b551c3bbb8dc3a16d [file] [log] [blame]
Ralph Nathand3472e12016-02-12 16:24:15 -08001# Copyright (c) 2016 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5"""Server side audio utilities functions for Brillo."""
6
7import contextlib
8import logging
9import numpy
10import os
11import struct
12import subprocess
13import tempfile
14import wave
15
16from autotest_lib.client.common_lib import error
17
18
# Number of bits in a byte; used to convert a sample width in bytes into the
# bit depth passed to sox.
_BITS_PER_BYTE=8

# Thresholds used when comparing files.
#
# The frequency threshold used when comparing files, expressed as a fraction
# (0.01 == 1%). The dominant frequency of the recorded audio has to be within
# this fraction of the dominant frequency of the original audio.
_FREQUENCY_THRESHOLD = 0.01
# Noise threshold controls how much noise is allowed, expressed as a fraction
# (0.05 == 5%) of the magnitude of the peak frequency after taking an FFT. The
# FFT magnitude of every other frequency in the signal should stay below this
# fraction of the magnitude of the main frequency.
_FFT_NOISE_THRESHOLD = 0.05

# Command used to encode audio into a non-WAV format. If you want to test
# with something different, this should be changed.
_ENCODING_CMD = 'sox'
36
Ralph Nathand3472e12016-02-12 16:24:15 -080037
def extract_wav_frames(wave_file):
    """Read every sample of a WAV file as a signed integer.

    @param wave_file: A Wave_read object representing a WAV file opened for
                      reading.

    @return: A flat list holding the samples of the WAV file. Samples of
             multi-channel files are interleaved, so the samples of channel
             c are found at result[c::number_of_channels].

    @raise ValueError: The sample width is not 1, 2 or 4 bytes.
    """
    # struct type code for each supported sample width (in bytes).
    type_codes = {1: 'B', 2: 'h', 4: 'i'}

    frame_count = wave_file.getnframes()
    width = wave_file.getsampwidth()
    if width not in type_codes:
        raise ValueError('Unsupported sample width')

    # Every frame carries one sample per channel.
    sample_count = frame_count * wave_file.getnchannels()
    layout = '%i%s' % (sample_count, type_codes[width])
    raw_data = wave_file.readframes(frame_count)
    samples = list(struct.unpack(layout, raw_data))

    # 8-bit PCM is stored unsigned with an offset of 128. The rest of the code
    # assumes signed numbers, so remove the offset.
    if width == 1:
        samples = [sample - 128 for sample in samples]

    return samples
65
66
def check_wav_file(filename, num_channels=None, sample_rate=None,
                   sample_width=None):
    """Checks a WAV file and returns its peak PCM values.

    @param filename: Input WAV file to analyze.
    @param num_channels: Number of channels to expect (None to not check).
    @param sample_rate: Sample rate to expect (None to not check).
    @param sample_width: Sample width to expect (None to not check).

    @return A list of the absolute maximum PCM values for each channel in the
            WAV file.

    @raise ValueError: Failed to process the WAV file or validate an attribute.
    """
    try:
        # Wave_read objects are not context managers on every supported
        # Python version, so rely on contextlib.closing to close the file.
        with contextlib.closing(wave.open(filename, 'r')) as chk_file:
            # Remember the channel count so it is not queried on the closed
            # file once the samples have been read.
            actual_channels = chk_file.getnchannels()
            if num_channels is not None and actual_channels != num_channels:
                raise ValueError('Expected %d channels but got %d instead.' %
                                 (num_channels, actual_channels))
            actual_rate = chk_file.getframerate()
            if sample_rate is not None and actual_rate != sample_rate:
                raise ValueError(
                        'Expected sample rate %d but got %d instead.' %
                        (sample_rate, actual_rate))
            actual_width = chk_file.getsampwidth()
            if sample_width is not None and actual_width != sample_width:
                raise ValueError(
                        'Expected sample width %d but got %d instead.' %
                        (sample_width, actual_width))
            frames = extract_wav_frames(chk_file)
    except (wave.Error, EOFError) as e:
        # The wave module raises EOFError for empty or truncated files.
        raise ValueError('Error processing WAV file: %s' % e)

    if not frames:
        # max() of an empty sequence would raise a ValueError with an
        # unhelpful message; be explicit instead.
        raise ValueError('WAV file contains no audio frames.')

    # Samples are interleaved, so channel c owns every actual_channels-th
    # sample starting at offset c.
    return [max(abs(val) for val in frames[channel::actual_channels])
            for channel in range(actual_channels)]
104
105
def generate_sine_file(host, num_channels, sample_rate, sample_width,
                       duration_secs, sine_frequency, temp_dir,
                       file_format='wav'):
    """Generate a sine file and push it to the DUT.

    The sine wave is generated with sox as WAV data. If file_format is not
    'wav', the generated file is additionally converted with _ENCODING_CMD
    and the converted file is the one pushed to the DUT.

    @param host: An object representing the DUT.
    @param num_channels: Number of channels to use.
    @param sample_rate: Sample rate to use for sine wave generation.
    @param sample_width: Sample width to use for sine wave generation.
    @param duration_secs: Duration in seconds to generate sine wave for.
    @param sine_frequency: Frequency to generate sine wave at.
    @param temp_dir: A temporary directory on the host.
    @param file_format: A string representing the encoding for the audio file.

    @return A tuple of the filename on the server and the DUT. The server
            filename is always the file holding the generated WAV data.

    @raise subprocess.CalledProcessError: Generating or converting the file
                                          failed.
    """
    fd, local_filename = tempfile.mkstemp(
            prefix='sine-', suffix='.' + file_format, dir=temp_dir)
    # mkstemp returns an open descriptor. The file is only written by name
    # (by sox), so close the descriptor right away instead of leaking it.
    os.close(fd)

    # 8-bit PCM is unsigned, every other width is signed.
    encoding = 'unsigned' if sample_width == 1 else 'signed'

    # Build the command as an argument list so that no shell is involved and
    # paths containing spaces or shell metacharacters are handled correctly.
    gen_file_cmd = ['sox', '-n', '-t', 'wav',
                    '-c', '%d' % num_channels,
                    '-e', encoding,
                    '-b', '%d' % (sample_width * _BITS_PER_BYTE),
                    '-r', '%d' % sample_rate,
                    local_filename,
                    'synth', '%d' % duration_secs,
                    'sine', '%d' % sine_frequency,
                    'vol', '0.9']
    logging.info('Command to generate sine wave: %s', ' '.join(gen_file_cmd))
    # Fail right away if sox fails rather than pushing a bogus file.
    subprocess.check_call(gen_file_cmd)

    if file_format != 'wav':
        # Convert the file to the appropriate format.
        logging.info('Converting file to %s', file_format)
        fd, local_encoded_filename = tempfile.mkstemp(
                prefix='sine-', suffix='.' + file_format, dir=temp_dir)
        os.close(fd)
        # split() allows _ENCODING_CMD to carry extra arguments.
        cvt_file_cmd = _ENCODING_CMD.split() + [local_filename,
                                                local_encoded_filename]
        logging.info('Command to convert file: %s', ' '.join(cvt_file_cmd))
        subprocess.check_call(cvt_file_cmd)
    else:
        local_encoded_filename = local_filename

    # TODO(ralphnathan): Find a better place to put this file once the SELinux
    # issues are resolved.
    dut_tmp_dir = '/data'
    remote_filename = os.path.join(dut_tmp_dir, 'sine.' + file_format)
    logging.info('Send file to DUT.')
    logging.info('remote_filename %s', remote_filename)
    host.send_file(local_encoded_filename, remote_filename)
    return local_filename, remote_filename
153
154
def _is_outside_frequency_threshold(freq_reference, freq_rec):
    """Compares the frequency of the recorded audio with the reference audio.

    This function checks whether the frequencies corresponding to the peak
    FFT values are similar, meaning that the dominant frequency in the audio
    signal is the same for the recorded audio as that in the audio played.

    @param freq_reference: The dominant frequency in the reference audio file.
    @param freq_rec: The dominant frequency in the recorded audio file.

    @return: True if freq_rec deviates from freq_reference by more than
             _FREQUENCY_THRESHOLD (a fraction of freq_reference), False if it
             is within that tolerance.
    """
    # Compare the absolute deviation against the allowed fraction of the
    # reference instead of dividing by the reference. This is equivalent to
    # checking the ratio against 1 +/- _FREQUENCY_THRESHOLD, but it does not
    # divide by zero when the reference frequency is 0.
    deviation = abs(float(freq_rec) - freq_reference)
    return bool(deviation > _FREQUENCY_THRESHOLD * abs(freq_reference))
172
173
def _compare_frames(reference_file_frames, rec_file_frames, num_channels,
                    sample_rate):
    """Compares audio frames from the reference file and the recorded file.

    This method checks for two things:
        1. That the main frequency is the same in both the files. This is done
           using the FFT and observing the frequency corresponding to the
           peak.
        2. That there is no other dominant frequency in the recorded file.
           This is done by sweeping the frequency domain and checking that
           the FFT magnitude of every frequency away from the main frequency
           stays below the _FFT_NOISE_THRESHOLD fraction of the peak.

    The key assumption here is that the reference audio file contains only
    one frequency.

    @param reference_file_frames: Audio frames from the reference file.
    @param rec_file_frames: Audio frames from the recorded file.
    @param num_channels: Number of channels in the files.
    @param sample_rate: Sample rate of the files.

    @raise error.TestFail: The frequency of the recorded signal doesn't
                           match that of the reference signal.
    @raise error.TestFail: There is too much noise in the recorded signal.
    """
    for channel in range(num_channels):
        # Samples are interleaved, so pick every num_channels-th sample.
        reference_data = reference_file_frames[channel::num_channels]
        rec_data = rec_file_frames[channel::num_channels]

        # Get fft and frequencies corresponding to the fft values.
        fft_reference = numpy.fft.rfft(reference_data)
        fft_rec = numpy.fft.rfft(rec_data)
        fft_freqs_reference = numpy.fft.rfftfreq(len(reference_data),
                                                 1.0 / sample_rate)
        fft_freqs_rec = numpy.fft.rfftfreq(len(rec_data), 1.0 / sample_rate)

        # Get frequency at highest peak.
        freq_reference = fft_freqs_reference[
                numpy.argmax(numpy.abs(fft_reference))]
        abs_fft_rec = numpy.abs(fft_rec)
        freq_rec = fft_freqs_rec[numpy.argmax(abs_fft_rec)]

        # Compare the two frequencies.
        logging.info('Golden frequency of channel %i is %f', channel,
                     freq_reference)
        logging.info('Recorded frequency of channel %i is %f', channel,
                     freq_rec)
        if _is_outside_frequency_threshold(freq_reference, freq_rec):
            raise error.TestFail('The recorded audio frequency does not match '
                                 'that of the audio played.')

        # Check for noise in the frequency domain.
        fft_rec_peak_val = numpy.max(abs_fft_rec)
        noise_detected = False
        for bin_freq, fft_val in zip(fft_freqs_rec, abs_fft_rec):
            # Bins close to the reference frequency belong to the expected
            # tone itself. Only bins outside that band count as noise, so the
            # check has to look at the frequency of the bin being examined
            # rather than at the peak frequency, which was validated above.
            if not _is_outside_frequency_threshold(freq_reference, bin_freq):
                continue
            # If the magnitude exceeds _FFT_NOISE_THRESHOLD, then fail.
            if fft_val > _FFT_NOISE_THRESHOLD * fft_rec_peak_val:
                logging.warning('Unexpected frequency peak detected at %f '
                                'Hz.', bin_freq)
                noise_detected = True

        if noise_detected:
            raise error.TestFail('Signal is noiser than expected.')
238
239
def compare_file(reference_audio_filename, test_audio_filename):
    """Compares the recorded audio file to the reference audio file.

    Both files are opened as WAV files and their frames are handed to
    _compare_frames together with the channel count and the sample rate of
    the reference file.

    @param reference_audio_filename : Reference audio file containing the
                                      reference signal.
    @param test_audio_filename: Audio file containing audio captured from
                                the test.

    @raise error.TestFail: The recorded file does not have the same number of
                           channels or the same sample rate as the reference
                           file, or the frame comparison failed.
    """
    with contextlib.closing(wave.open(reference_audio_filename,
                                      'rb')) as reference_file:
        with contextlib.closing(wave.open(test_audio_filename,
                                          'rb')) as rec_file:
            num_channels = reference_file.getnchannels()
            sample_rate = reference_file.getframerate()

            # The comparison de-interleaves both files with the same channel
            # count and maps FFT bins to frequencies with the same sample
            # rate, so its results are meaningless if the files differ.
            if rec_file.getnchannels() != num_channels:
                raise error.TestFail(
                        'Recorded file has %d channels but the reference '
                        'file has %d.' % (rec_file.getnchannels(),
                                          num_channels))
            if rec_file.getframerate() != sample_rate:
                raise error.TestFail(
                        'Recorded file has sample rate %d but the reference '
                        'file has %d.' % (rec_file.getframerate(),
                                          sample_rate))

            # Extract data from files.
            reference_file_frames = extract_wav_frames(reference_file)
            rec_file_frames = extract_wav_frames(rec_file)

            _compare_frames(reference_file_frames, rec_file_frames,
                            num_channels, sample_rate)
259 reference_file.getframerate())