// Copyright (C) 2019 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15#ifndef IORAP_SRC_DB_MODELS_H_
16#define IORAP_SRC_DB_MODELS_H_
17
#include <android-base/logging.h>

#include <cstdint>
#include <memory>
#include <mutex>
#include <optional>
#include <ostream>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>

#include <sqlite3.h>
27
28namespace iorap::db {
29
30struct SqliteDbDeleter {
31 void operator()(sqlite3* db) {
32 if (db != nullptr) {
33 LOG(VERBOSE) << "sqlite3_close for: " << db;
34 sqlite3_close(db);
35 }
36 }
37};
38
39class DbHandle {
40 public:
41 // Take over ownership of sqlite3 db.
42 explicit DbHandle(sqlite3* db)
43 : db_{std::shared_ptr<sqlite3>{db, SqliteDbDeleter{}}},
44 mutex_{std::make_shared<std::mutex>()} {
45 }
46
47 sqlite3* get() {
48 return db_.get();
49 }
50
51 operator sqlite3*() {
52 return db_.get();
53 }
54
55 std::mutex& mutex() {
56 return *mutex_.get();
57 }
58
59 private:
60 std::shared_ptr<sqlite3> db_;
61 std::shared_ptr<std::mutex> mutex_;
62};
63
64class ScopedLockDb {
65 public:
66 ScopedLockDb(std::mutex& mutex) : mutex(mutex) {
67 mutex.lock();
68 }
69
70 ScopedLockDb(DbHandle& handle) : ScopedLockDb(handle.mutex()) {
71 }
72
73 ~ScopedLockDb() {
74 mutex.unlock();
75 }
76 private:
77 std::mutex& mutex;
78};
79
80class DbStatement {
81 public:
82 template <typename ... Args>
83 static DbStatement Prepare(DbHandle db, const std::string& sql, Args&&... args) {
84 return Prepare(db, sql.c_str(), std::forward<Args>(args)...);
85 }
86
87 template <typename ... Args>
88 static DbStatement Prepare(DbHandle db, const char* sql, Args&&... args) {
89 DCHECK(db.get() != nullptr);
90 DCHECK(sql != nullptr);
91
92 // LOG(VERBOSE) << "Prepare DB=" << db.get();
93
94 sqlite3_stmt* stmt = nullptr;
95 int rc = sqlite3_prepare_v2(db.get(), sql, -1, /*out*/&stmt, nullptr);
96
97 DbStatement db_stmt{db, stmt};
98 DCHECK(db_stmt.CheckOk(rc)) << sql;
99 db_stmt.BindAll(std::forward<Args>(args)...);
100
101 return db_stmt;
102 }
103
104 DbStatement(DbHandle db, sqlite3_stmt* stmt) : db_(db), stmt_(stmt) {
105 }
106
107 sqlite3_stmt* get() {
108 return stmt_;
109 }
110
111 DbHandle db() {
112 return db_;
113 }
114
115 template <typename T, typename ... Args>
116 void BindAll(T&& arg, Args&&... args) {
117 Bind(std::forward<T>(arg));
118 BindAll(std::forward<Args>(args)...);
119 }
120
121 void BindAll() {}
122
123 template <typename T>
124 void Bind(const std::optional<T>& value) {
125 if (value) {
126 Bind(*value);
127 } else {
128 BindNull();
129 }
130 }
131
132 void Bind(bool value) {
133 CheckOk(sqlite3_bind_int(stmt_, bind_counter_++, value));
134 }
135
136 void Bind(int value) {
137 CheckOk(sqlite3_bind_int(stmt_, bind_counter_++, value));
138 }
139
140 void Bind(uint64_t value) {
141 CheckOk(sqlite3_bind_int64(stmt_, bind_counter_++, static_cast<int64_t>(value)));
142 }
143
144 void Bind(const char* value) {
145 if (value != nullptr) {
146 //sqlite3_bind_text(stmt_, /*index*/bind_counter_++, value, -1, SQLITE_STATIC);
147 CheckOk(sqlite3_bind_text(stmt_, /*index*/bind_counter_++, value, -1, SQLITE_TRANSIENT));
148 } else {
149 BindNull();
150 }
151 }
152
153 void Bind(const std::string& value) {
154 Bind(value.c_str());
155 }
156
157 template <typename E, typename = std::enable_if_t<std::is_enum_v<E>>>
158 void Bind(E value) {
159 Bind(static_cast<std::underlying_type_t<E>>(value));
160 }
161
162 void BindNull() {
163 CheckOk(sqlite3_bind_null(stmt_, bind_counter_++));
164 }
165
166 int Step() {
167 ++step_counter_;
168 return sqlite3_step(stmt_);
169 }
170
171 bool Step(int expected) {
172 int rc = Step();
173 if (rc != expected) {
174 LOG(ERROR) << "SQLite error: " << sqlite3_errmsg(db_.get());
175 return false;
176 }
177 return true;
178 }
179
180 template <typename T, typename ... Args>
181 void ColumnAll(T& first, Args&... rest) {
182 Column(first);
183 ColumnAll(rest...);
184 }
185
186 void ColumnAll() {}
187
188 template <typename T>
189 void Column(std::optional<T>& value) {
190 T tmp;
191 Column(/*out*/tmp);
192
193 if (!tmp) { // disambiguate 0 and NULL
194 const unsigned char* text = sqlite3_column_text(stmt_, column_counter_ - 1);
195 if (text == nullptr) {
196 value = std::nullopt;
197 return;
198 }
199 }
200 value = std::move(tmp);
201 }
202
203 template <typename E, typename = std::enable_if_t<std::is_enum_v<E>>>
204 void Column(E& value) {
205 std::underlying_type_t<E> tmp;
206 Column(/*out*/tmp);
207 value = static_cast<E>(tmp);
208 }
209
210 void Column(bool& value) {
211 value = sqlite3_column_int(stmt_, column_counter_++);
212 }
213
214 void Column(int& value) {
215 value = sqlite3_column_int(stmt_, column_counter_++);
216 }
217
218 void Column(uint64_t& value) {
219 value = static_cast<uint64_t>(sqlite3_column_int64(stmt_, column_counter_++));
220 }
221
222 void Column(std::string& value) {
223 const unsigned char* text = sqlite3_column_text(stmt_, column_counter_++);
224 value = std::string{reinterpret_cast<const char*>(text)};
225 }
226
227 const char* ExpandedSql() {
228 char* p = sqlite3_expanded_sql(stmt_);
229 if (p == nullptr) {
230 return "(nullptr)";
231 }
232 return p;
233 }
234
235 const char* Sql() {
236 const char* p = sqlite3_sql(stmt_);
237 if (p == nullptr) {
238 return "(nullptr)";
239 }
240 return p;
241 }
242
243
244 DbStatement(DbStatement&& other)
245 : db_{other.db_}, stmt_{other.stmt_}, bind_counter_{other.bind_counter_},
246 step_counter_{other.step_counter_} {
247 other.db_ = DbHandle{nullptr};
248 other.stmt_ = nullptr;
249 }
250
251 ~DbStatement() {
252 if (stmt_ != nullptr) {
253 DCHECK_GT(step_counter_, 0) << " forgot to call Step()?";
254 sqlite3_finalize(stmt_);
255 }
256 }
257
258 private:
259 bool CheckOk(int rc, int expected = SQLITE_OK) {
260 if (rc != expected) {
261 LOG(ERROR) << "Got error for SQL query: '" << Sql() << "'"
262 << ", expanded: '" << ExpandedSql() << "'";
263 LOG(ERROR) << "Failed SQLite api call (" << rc << "): " << sqlite3_errstr(rc);
264 }
265 return rc == expected;
266 }
267
268 DbHandle db_;
269 sqlite3_stmt* stmt_;
270 int bind_counter_ = 1;
271 int step_counter_ = 0;
272 int column_counter_ = 0;
273};
274
275class DbQueryBuilder {
276 public:
277 // Returns the row ID that was inserted last.
278 template <typename... Args>
279 static std::optional<int> Insert(DbHandle db, const std::string& sql, Args&&... args) {
280 ScopedLockDb lock{db};
281
282 sqlite3_int64 last_rowid = sqlite3_last_insert_rowid(db.get());
283 DbStatement stmt = DbStatement::Prepare(db, sql, std::forward<Args>(args)...);
284
285 if (!stmt.Step(SQLITE_DONE)) {
286 return std::nullopt;
287 }
288
289 last_rowid = sqlite3_last_insert_rowid(db.get());
290 DCHECK_GT(last_rowid, 0);
291
292 return static_cast<int>(last_rowid);
293 }
294
295 template <typename... Args>
296 static bool SelectOnce(DbStatement& stmt, Args&... args) {
297 int rc = stmt.Step();
298 switch (rc) {
299 case SQLITE_ROW:
300 stmt.ColumnAll(/*out*/args...);
301 return true;
302 case SQLITE_DONE:
303 return false;
304 default:
305 LOG(ERROR) << "Failed to step (" << rc << "): " << sqlite3_errmsg(stmt.db());
306 return false;
307 }
308 }
309};
310
311class Model {
312 public:
313 DbHandle db() const {
314 return db_;
315 }
316
317 Model(DbHandle db) : db_{db} {
318 }
319
320 private:
321 DbHandle db_;
322};
323
324class SchemaModel : public Model {
325 public:
326 static SchemaModel GetOrCreate(std::string location) {
327 int rc = sqlite3_config(SQLITE_CONFIG_LOG, ErrorLogCallback, /*data*/nullptr);
328
329 if (rc != SQLITE_OK) {
330 LOG(FATAL) << "Failed to configure logging";
331 }
332
333 sqlite3* db = nullptr;
334 if (location != ":memory:") {
335 // Try to open DB if it already exists.
336 rc = sqlite3_open_v2(location.c_str(), /*out*/&db, SQLITE_OPEN_READWRITE, /*vfs*/nullptr);
337
338 if (rc == SQLITE_OK) {
339 LOG(INFO) << "Opened existing database at '" << location << "'";
340 return SchemaModel{DbHandle{db}, location};
341 }
342 }
343
344 // Create a new DB if one didn't exist already.
345 rc = sqlite3_open(location.c_str(), /*out*/&db);
346
347 if (rc != SQLITE_OK) {
348 LOG(FATAL) << "Failed to open DB: " << sqlite3_errmsg(db);
349 }
350
351 SchemaModel schema{DbHandle{db}, location};
352 schema.Reinitialize();
353 // TODO: migrate versions upwards when we rev the schema version
354
355 int old_version = schema.Version();
356 LOG(VERBOSE) << "Loaded schema version: " << old_version;
357
358 return schema;
359 }
360
361 void MarkSingleton() {
362 s_singleton_ = db();
363 }
364
365 static DbHandle GetSingleton() {
366 DCHECK(s_singleton_.has_value());
367 return *s_singleton_;
368 }
369
370 void Reinitialize() {
371 const char* sql_to_initialize = R"SQLC0D3(
372 DROP TABLE IF EXISTS schema_versions;
373 DROP TABLE IF EXISTS packages;
374 DROP TABLE IF EXISTS activities;
375 DROP TABLE IF EXISTS app_launch_histories;
376 DROP TABLE IF EXISTS raw_traces;
377 DROP TABLE IF EXISTS prefetch_files;
378)SQLC0D3";
379 char* err_msg = nullptr;
380 int rc = sqlite3_exec(db().get(),
381 sql_to_initialize,
382 /*callback*/nullptr,
383 /*arg*/0,
384 /*out*/&err_msg);
385 if (rc != SQLITE_OK) {
386 LOG(FATAL) << "Failed to drop tables: " << err_msg ? err_msg : "nullptr";
387 }
388
389 CreateSchema();
390 LOG(INFO) << "Reinitialized database at '" << location_ << "'";
391 }
392
393 int Version() {
394 std::string query = "SELECT MAX(version) FROM schema_versions;";
395 DbStatement stmt = DbStatement::Prepare(db(), query);
396
397 int return_value = 0;
398 if (!DbQueryBuilder::SelectOnce(stmt, /*out*/return_value)) {
399 LOG(ERROR) << "Failed to query schema version";
400 return -1;
401 }
402
403 return return_value;
404 }
405
406 protected:
407 SchemaModel(DbHandle db, std::string location) : Model{db}, location_(location) {
408 }
409
410 private:
411 static std::optional<DbHandle> s_singleton_;
412
413 void CreateSchema() {
414 const char* sql_to_initialize = R"SQLC0D3(
415 CREATE TABLE schema_versions(
416 version INTEGER NOT NULL
417 );
418 INSERT INTO schema_versions VALUES(1);
419
420 CREATE TABLE packages(
421 id INTEGER NOT NULL,
422 name TEXT NOT NULL,
423 version INTEGER,
424
425 PRIMARY KEY(id)
426 );
427
428 CREATE TABLE activities(
429 id INTEGER NOT NULL,
430 name TEXT NOT NULL,
431 package_id INTEGER NOT NULL,
432
433 PRIMARY KEY(id),
434 FOREIGN KEY (package_id) REFERENCES packages (id)
435 );
436
437 CREATE TABLE app_launch_histories(
438 id INTEGER NOT NULL PRIMARY KEY,
439 activity_id INTEGER NOT NULL,
440 -- 1:Cold, 2:Warm, 3:Hot
441 temperature INTEGER CHECK (temperature IN (1, 2, 3)) NOT NULL,
442 trace_enabled INTEGER CHECK(trace_enabled in (TRUE, FALSE)) NOT NULL,
443 readahead_enabled INTEGER CHECK(trace_enabled in (TRUE, FALSE)) NOT NULL,
444 -- absolute timestamp since epoch
445 intent_started_ns INTEGER CHECK(intent_started_ns IS NULL or intent_started_ns >= 0),
446 -- absolute timestamp since epoch
447 total_time_ns INTEGER CHECK(total_time_ns IS NULL or total_time_ns >= 0),
448 -- absolute timestamp since epoch
449 report_fully_drawn_ns INTEGER CHECK(report_fully_drawn_ns IS NULL or report_fully_drawn_ns >= 0),
450
451 FOREIGN KEY (activity_id) REFERENCES activities (id)
452 );
453
454 CREATE TABLE raw_traces(
455 id INTEGER NOT NULL PRIMARY KEY,
456 history_id INTEGER NOT NULL,
457 file_path TEXT NOT NULL,
458
459 FOREIGN KEY (history_id) REFERENCES app_launch_histories (id)
460 );
461
462 CREATE TABLE prefetch_files(
463 id INTEGER NOT NULL PRIMARY KEY,
464 activity_id INTEGER NOT NULL,
465 file_path TEXT NOT NULL,
466
467 FOREIGN KEY (activity_id) REFERENCES activities (id)
468 );
469)SQLC0D3";
470
471 char* err_msg = nullptr;
472 int rc = sqlite3_exec(db().get(),
473 sql_to_initialize,
474 /*callback*/nullptr,
475 /*arg*/0,
476 /*out*/&err_msg);
477
478 if (rc != SQLITE_OK) {
479 LOG(FATAL) << "Failed to create tables: " << err_msg ? err_msg : "nullptr";
480 }
481 }
482
483 static void ErrorLogCallback(void *pArg, int iErrCode, const char *zMsg) {
484 LOG(ERROR) << "SQLite error (" << iErrCode << "): " << zMsg;
485 }
486
487 std::string location_;
488};
489
490class PackageModel : public Model {
491 protected:
492 PackageModel(DbHandle db) : Model{db} {
493 }
494
495 public:
496 static std::optional<PackageModel> SelectById(DbHandle db, int id) {
497 ScopedLockDb lock{db};
498 int original_id = id;
499
500 std::string query = "SELECT * FROM packages WHERE id = ?1 LIMIT 1;";
501 DbStatement stmt = DbStatement::Prepare(db, query, id);
502
503 PackageModel p{db};
504 if (!DbQueryBuilder::SelectOnce(stmt, p.id, p.name, p.version)) {
505 return std::nullopt;
506 }
507
508 return p;
509 }
510
511 static std::optional<PackageModel> SelectByName(DbHandle db, const char* name) {
512 ScopedLockDb lock{db};
513
514 std::string query = "SELECT * FROM packages WHERE name = ?1 LIMIT 1;";
515 DbStatement stmt = DbStatement::Prepare(db, query, name);
516
517 PackageModel p{db};
518 if (!DbQueryBuilder::SelectOnce(stmt, p.id, p.name, p.version)) {
519 return std::nullopt;
520 }
521
522 return p;
523 }
524
525 static std::optional<PackageModel> Insert(DbHandle db,
526 std::string name,
527 std::optional<int> version) {
528 const char* sql = "INSERT INTO packages (name, version) VALUES (?1, ?2);";
529
530 std::optional<int> inserted_row_id =
531 DbQueryBuilder::Insert(db, sql, name, version);
532 if (!inserted_row_id) {
533 return std::nullopt;
534 }
535
536 PackageModel p{db};
537 p.name = name;
538 p.version = version;
539 p.id = *inserted_row_id;
540
541 return p;
542 }
543
544 int id;
545 std::string name;
546 std::optional<int> version;
547};
548
549inline std::ostream& operator<<(std::ostream& os, const PackageModel& p) {
550 os << "PackageModel{id=" << p.id << ",name=" << p.name << ",";
551 os << "version=";
552 if (p.version) {
553 os << *p.version;
554 } else {
555 os << "(nullopt)";
556 }
557 os << "}";
558 return os;
559}
560
561class ActivityModel : public Model {
562 protected:
563 ActivityModel(DbHandle db) : Model{db} {
564 }
565
566 public:
567 static std::optional<ActivityModel> SelectById(DbHandle db, int id) {
568 ScopedLockDb lock{db};
569 int original_id = id;
570
571 std::string query = "SELECT * FROM activities WHERE id = ? LIMIT 1;";
572 DbStatement stmt = DbStatement::Prepare(db, query, id);
573
574 ActivityModel p{db};
575 if (!DbQueryBuilder::SelectOnce(stmt, p.id, p.name, p.package_id)) {
576 return std::nullopt;
577 }
578
579 return p;
580 }
581
582 static std::optional<ActivityModel> SelectByNameAndPackageId(DbHandle db,
583 const char* name,
584 int package_id) {
585 ScopedLockDb lock{db};
586
587 std::string query = "SELECT * FROM activities WHERE name = ? AND package_id = ? LIMIT 1;";
588 DbStatement stmt = DbStatement::Prepare(db, query, name, package_id);
589
590 ActivityModel p{db};
591 if (!DbQueryBuilder::SelectOnce(stmt, p.id, p.name, p.package_id)) {
592 return std::nullopt;
593 }
594
595 return p;
596 }
597
598 static std::optional<ActivityModel> Insert(DbHandle db,
599 std::string name,
600 int package_id) {
601 const char* sql = "INSERT INTO activities (name, package_id) VALUES (?1, ?2);";
602
603 std::optional<int> inserted_row_id =
604 DbQueryBuilder::Insert(db, sql, name, package_id);
605 if (!inserted_row_id) {
606 return std::nullopt;
607 }
608
609 ActivityModel p{db};
610 p.id = *inserted_row_id;
611 p.name = name;
612 p.package_id = package_id;
613
614 return p;
615 }
616
617 // Try to select by package_name+activity_name, otherwise insert into both tables.
618 // Package version is ignored for selects.
619 static std::optional<ActivityModel> SelectOrInsert(
620 DbHandle db,
621 std::string package_name,
622 std::optional<int> package_version,
623 std::string activity_name) {
624 std::optional<PackageModel> package = PackageModel::SelectByName(db, package_name.c_str());
625 if (!package) {
626 package = PackageModel::Insert(db, package_name, package_version);
627 DCHECK(package.has_value());
628 }
629
630 std::optional<ActivityModel> activity =
631 ActivityModel::SelectByNameAndPackageId(db,
632 activity_name.c_str(),
633 package->id);
634 if (!activity) {
635 activity = Insert(db, activity_name, package->id);
636 // XX: should we really return an optional here? This feels like it should never fail.
637 }
638
639 return activity;
640 }
641
642 int id;
643 std::string name;
644 int package_id; // PackageModel::id
645};
646
647inline std::ostream& operator<<(std::ostream& os, const ActivityModel& p) {
648 os << "ActivityModel{id=" << p.id << ",name=" << p.name << ",";
649 os << "package_id=" << p.package_id << "}";
650 return os;
651}
652
653class AppLaunchHistoryModel : public Model {
654 protected:
655 AppLaunchHistoryModel(DbHandle db) : Model{db} {
656 }
657
658 public:
659 enum class Temperature : int32_t {
660 kUninitialized = -1, // Note: Not a valid SQL value.
661 kCold = 1,
662 kWarm = 2,
663 kHot = 3,
664 };
665
666 static std::optional<AppLaunchHistoryModel> SelectById(DbHandle db, int id) {
667 ScopedLockDb lock{db};
668 int original_id = id;
669
670 std::string query = "SELECT * FROM app_launch_histories WHERE id = ? LIMIT 1;";
671 DbStatement stmt = DbStatement::Prepare(db, query, id);
672
673 AppLaunchHistoryModel p{db};
674 if (!DbQueryBuilder::SelectOnce(stmt,
675 p.id,
676 p.activity_id,
677 p.temperature,
678 p.trace_enabled,
679 p.readahead_enabled,
680 p.total_time_ns,
681 p.report_fully_drawn_ns)) {
682 return std::nullopt;
683 }
684
685 return p;
686 }
687
688 static std::optional<AppLaunchHistoryModel> Insert(DbHandle db,
689 int activity_id,
690 AppLaunchHistoryModel::Temperature temperature,
691 bool trace_enabled,
692 bool readahead_enabled,
693 std::optional<uint64_t> total_time_ns,
694 std::optional<uint64_t> report_fully_drawn_ns)
695 {
696 const char* sql = "INSERT INTO app_launch_histories (activity_id, temperature, trace_enabled, "
697 "readahead_enabled, total_time_ns, "
698 "report_fully_drawn_ns) "
699 "VALUES (?1, ?2, ?3, ?4, ?5, ?6);";
700
701 std::optional<int> inserted_row_id =
702 DbQueryBuilder::Insert(db,
703 sql,
704 activity_id,
705 temperature,
706 trace_enabled,
707 readahead_enabled,
708 total_time_ns,
709 report_fully_drawn_ns);
710 if (!inserted_row_id) {
711 return std::nullopt;
712 }
713
714 AppLaunchHistoryModel p{db};
715 p.id = *inserted_row_id;
716 p.activity_id = activity_id;
717 p.temperature = temperature;
718 p.trace_enabled = trace_enabled;
719 p.readahead_enabled = readahead_enabled;
720 p.total_time_ns = total_time_ns;
721 p.report_fully_drawn_ns = report_fully_drawn_ns;
722
723 return p;
724 }
725
726 int id;
727 int activity_id; // ActivityModel::id
728 Temperature temperature = Temperature::kUninitialized;
729 bool trace_enabled;
730 bool readahead_enabled;
731 std::optional<uint64_t> total_time_ns;
732 std::optional<uint64_t> report_fully_drawn_ns;
733};
734
735inline std::ostream& operator<<(std::ostream& os, const AppLaunchHistoryModel& p) {
736 os << "AppLaunchHistoryModel{id=" << p.id << ","
737 << "activity_id=" << p.activity_id << ","
738 << "temperature=" << static_cast<int>(p.temperature) << ","
739 << "trace_enabled=" << p.trace_enabled << ","
740 << "readahead_enabled=" << p.readahead_enabled << ","
741 << "total_time_ns=";
742 if (p.total_time_ns) {
743 os << *p.total_time_ns;
744 } else {
745 os << "(nullopt)";
746 }
747 os << ",";
748 os << "report_fully_drawn_ns=";
749 if (p.report_fully_drawn_ns) {
750 os << *p.report_fully_drawn_ns;
751 } else {
752 os << "(nullopt)";
753 }
754 os << "}";
755 return os;
756}
757
758} // namespace iorap::db
759
760#endif // IORAP_SRC_DB_MODELS_H_