blob: cbee7b8b2bde9197885245e802c3b8842f05c3f2 [file] [log] [blame]
Sergei Trofimov4e6afe92015-10-09 09:30:04 +01001import os
2import csv
3import tempfile
4from itertools import chain
5
6from devlib.instrument import Instrument, MeasurementsCsv, CONTINUOUS
7from devlib.exception import HostError
8from devlib.utils.misc import unique
9
# daqpower is an optional host-side dependency. If it cannot be imported, the
# names below are bound to None and the error text is kept in
# import_error_mesg so DaqInstrument can report it when instantiated.
try:
    from daqpower.client import execute_command, Status
    from daqpower.config import DeviceConfiguration, ServerConfiguration
except ImportError as e:  # 'as' form is valid on Python 2.6+ and Python 3
    execute_command, Status = None, None
    # NOTE(review): ConfigurationError is never imported on the success path;
    # it is only bound here. Kept in case other code imports it.
    DeviceConfiguration, ServerConfiguration, ConfigurationError = None, None, None
    # str(e) rather than e.message: BaseException.message is deprecated since
    # Python 2.6 and does not exist on Python 3.
    import_error_mesg = str(e)
17
18
class DaqInstrument(Instrument):
    """devlib instrument that collects measurements from a daqpower DAQ server.

    Every label gets a ``power`` and a ``voltage`` channel. All server
    communication goes through ``daqpower.client.execute_command`` (see
    ``execute``). If ``daqpower`` could not be imported, construction raises
    ``HostError``.
    """

    mode = CONTINUOUS

    def __init__(self, target, resistor_values,  # pylint: disable=R0914
                 labels=None,
                 host='localhost',
                 port=45677,
                 device_id='Dev1',
                 v_range=2.5,
                 dv_range=0.2,
                 sampling_rate=10000,
                 channel_map=(0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23),
                 ):
        """Validate the DAQ server/device and register one channel pair per label.

        :param target: devlib target, passed to ``Instrument.__init__``.
        :param resistor_values: one value per measured port; must have the
            same length as ``labels``.
        :param labels: site names; defaults to ``PORT_0``, ``PORT_1``, ...
            (one per resistor value).
        :param host: host name of the daqpower server.
        :param port: TCP port of the daqpower server.
        :param device_id: DAQ device that must be listed by the server when
            the ``list_devices`` query returns ``Status.OK``.
        :param v_range: passed unchanged to ``DeviceConfiguration``.
        :param dv_range: passed unchanged to ``DeviceConfiguration``.
        :param sampling_rate: passed unchanged to ``DeviceConfiguration``.
        :param channel_map: passed unchanged to ``DeviceConfiguration``.
        :raises HostError: if daqpower is unavailable, or the device query
            returns neither ``Status.OK`` nor ``Status.OKISH``.
        :raises ValueError: if label/resistor counts differ, or the device is
            not listed by the server.
        """
        # pylint: disable=no-member
        super(DaqInstrument, self).__init__(target)
        self._need_reset = True
        if execute_command is None:
            raise HostError('Could not import "daqpower": {}'.format(import_error_mesg))
        if labels is None:
            labels = ['PORT_{}'.format(i) for i in range(len(resistor_values))]
        if len(labels) != len(resistor_values):
            raise ValueError('"labels" and "resistor_values" must be of the same length')
        self.server_config = ServerConfiguration(host=host,
                                                 port=port)
        result = self.execute('list_devices')
        if result.status == Status.OK:
            if device_id not in result.data:
                raise ValueError('Device "{}" is not found on the DAQ server.'.format(device_id))
        elif result.status != Status.OKISH:
            # OKISH is tolerated: the device list is then not checked.
            raise HostError('Problem querying DAQ server: {}'.format(result.message))

        self.device_config = DeviceConfiguration(device_id=device_id,
                                                 v_range=v_range,
                                                 dv_range=dv_range,
                                                 sampling_rate=sampling_rate,
                                                 resistor_values=resistor_values,
                                                 channel_map=channel_map,
                                                 labels=labels)

        for label in labels:
            for kind in ['power', 'voltage']:
                self.add_channel(label, kind)

    def reset(self, sites=None, kinds=None):
        """Send ``close`` and then ``configure`` (with ``device_config``) to the server.

        :raises HostError: if ``configure`` does not return ``Status.OK``.
        """
        super(DaqInstrument, self).reset(sites, kinds)
        self.execute('close')
        result = self.execute('configure', config=self.device_config)
        if result.status != Status.OK:  # pylint: disable=no-member
            raise HostError(result.message)
        self._need_reset = False

    def start(self):
        """Start acquisition, reconfiguring first if not reset since the last stop."""
        if self._need_reset:
            self.reset()
        self.execute('start')

    def stop(self):
        """Stop acquisition; the next ``start`` will reconfigure the device."""
        self.execute('stop')
        self._need_reset = True

    def get_data(self, outfile):  # pylint: disable=R0914
        """Fetch raw per-site traces from the server and merge them into one CSV.

        The server's ``get_data`` command is expected to write one CSV file
        per site into a fresh temporary directory; each file's name without
        extension is taken as the site name. The first row of every raw file
        is a header; header column ``kind`` of site ``site`` is looked up as
        ``'<site>_<kind>'`` among the active channel labels. A site whose
        trace ends early is padded with empty cells (one per header column)
        until all traces are exhausted.

        :param outfile: path of the merged CSV to write.
        :returns: ``MeasurementsCsv`` over ``outfile`` for the active channels.
        :raises HostError: if the trace for an active site is missing or has
            no header row. The raw traces are kept in the temporary directory
            for inspection on any error. On success that directory is removed.
        """
        import shutil  # local import: only needed for tempdir cleanup

        tempdir = tempfile.mkdtemp(prefix='daq-raw-')
        self.execute('get_data', output_directory=tempdir)
        raw_file_map = {os.path.splitext(entry)[0]: os.path.join(tempdir, entry)
                        for entry in os.listdir(tempdir)}

        active_sites = unique([c.site for c in self.active_channels])
        file_handles = []
        try:
            # (reader, column count) kept in a fixed order, so the header
            # layout in channel_order and the per-row layout always agree.
            site_readers = []
            channel_order = []
            for site in active_sites:
                try:
                    site_file = raw_file_map[site]
                except KeyError:
                    message = 'Could not get DAQ trace for {}; Obtained traces are in {}'
                    raise HostError(message.format(site, tempdir))
                # Binary mode is what Python 2's csv module expects.
                fh = open(site_file, 'rb')
                file_handles.append(fh)
                reader = csv.reader(fh)
                # The first row is the header naming this site's columns.
                header = next(reader, None)
                if header is None:
                    message = 'DAQ trace for {} is empty; Obtained traces are in {}'
                    raise HostError(message.format(site, tempdir))
                channel_order.extend('{}_{}'.format(site, kind) for kind in header)
                site_readers.append((reader, len(header)))

            def _read_next_rows():
                """Concatenate the next row of every site; exhausted sites yield Nones."""
                parts = []
                for site_reader, width in site_readers:
                    # Pad with the site's real column count so later
                    # columns stay aligned with channel_order.
                    parts.extend(next(site_reader, [None] * width))
                return parts

            with open(outfile, 'wb') as wfh:
                field_names = [c.label for c in self.active_channels]
                # Resolve each output column's position in the merged raw
                # row once, instead of once per row.
                column_indexes = [channel_order.index(f) for f in field_names]
                writer = csv.writer(wfh)
                writer.writerow(field_names)
                raw_row = _read_next_rows()
                while any(raw_row):
                    writer.writerow([raw_row[i] for i in column_indexes])
                    raw_row = _read_next_rows()
        finally:
            for fh in file_handles:
                fh.close()

        # Merge succeeded: the raw traces are no longer needed.
        shutil.rmtree(tempdir, ignore_errors=True)
        return MeasurementsCsv(outfile, self.active_channels)

    def teardown(self):
        """Send ``close`` to the DAQ server."""
        self.execute('close')

    def execute(self, command, **kwargs):
        """Send ``command`` (with ``kwargs``) to the configured server and return the result."""
        return execute_command(self.server_config, command, **kwargs)
138