Add support for Delete queries in @Query annotations
This CL adds support for runinng DELETE queries in
@Query methods.
If the @Query method has predefined number of bind args,
we create a prepared statement and re-use it. If it
has variable number of args, we recreate the query every
time it is run.
Bug: 32342709
Test: SqlParserTest, SimpleEntityReadWriteTest, QueryMethodProcessorTest
Change-Id: I8d9ad83e36e1eed4ddd5e2d714ffb6cdd9881034
diff --git a/room/compiler/src/main/kotlin/com/android/support/room/ext/javapoet_ext.kt b/room/compiler/src/main/kotlin/com/android/support/room/ext/javapoet_ext.kt
index 17c2eb5..035d6d0 100644
--- a/room/compiler/src/main/kotlin/com/android/support/room/ext/javapoet_ext.kt
+++ b/room/compiler/src/main/kotlin/com/android/support/room/ext/javapoet_ext.kt
@@ -57,6 +57,8 @@
ClassName.get("com.android.support.room", "EntityInsertionAdapter")
val DELETE_OR_UPDATE_ADAPTER : ClassName =
ClassName.get("com.android.support.room", "EntityDeletionOrUpdateAdapter")
+ val SHARED_SQLITE_STMT : ClassName =
+ ClassName.get("com.android.support.room", "SharedSQLiteStatement")
}
object AndroidTypeNames {
diff --git a/room/compiler/src/main/kotlin/com/android/support/room/parser/ParsedQuery.kt b/room/compiler/src/main/kotlin/com/android/support/room/parser/ParsedQuery.kt
index 4c92336..5f63f1c 100644
--- a/room/compiler/src/main/kotlin/com/android/support/room/parser/ParsedQuery.kt
+++ b/room/compiler/src/main/kotlin/com/android/support/room/parser/ParsedQuery.kt
@@ -37,7 +37,7 @@
data class Table(val name: String, val alias: String)
-data class ParsedQuery(val original: String, val queryType: QueryType,
+data class ParsedQuery(val original: String, val type: QueryType,
val inputs: List<TerminalNode>,
// pairs of table name and alias,
val tables: Set<Table>,
@@ -85,10 +85,10 @@
}
private fun unknownQueryTypeErrors(): List<String> {
- return if (QueryType.SUPPORTED.contains(queryType)) {
+ return if (QueryType.SUPPORTED.contains(type)) {
emptyList()
} else {
- listOf(ParserErrors.invalidQueryType(queryType))
+ listOf(ParserErrors.invalidQueryType(type))
}
}
diff --git a/room/compiler/src/main/kotlin/com/android/support/room/parser/SqlParser.kt b/room/compiler/src/main/kotlin/com/android/support/room/parser/SqlParser.kt
index d789978..6266263 100644
--- a/room/compiler/src/main/kotlin/com/android/support/room/parser/SqlParser.kt
+++ b/room/compiler/src/main/kotlin/com/android/support/room/parser/SqlParser.kt
@@ -35,7 +35,7 @@
init {
queryType = (0..statement.childCount - 1).map {
findQueryType(statement.getChild(it))
- }.filterNot { it == QueryType.UNKNOWN }.first()
+ }.filterNot { it == QueryType.UNKNOWN }.firstOrNull() ?: QueryType.UNKNOWN
statement.accept(this)
}
@@ -123,7 +123,7 @@
return QueryVisitor(input, syntaxErrors, statement).createParsedQuery()
} catch (antlrError: RuntimeException) {
return ParsedQuery(input, QueryType.UNKNOWN, emptyList(), emptySet(),
- listOf(antlrError.message ?: "unknown error while parsing $input"))
+ listOf("unknown error while parsing $input : ${antlrError.message}"))
}
}
}
@@ -138,7 +138,7 @@
INSERT;
companion object {
- val SUPPORTED = hashSetOf(SELECT)
+ val SUPPORTED = hashSetOf(SELECT, DELETE)
}
}
diff --git a/room/compiler/src/main/kotlin/com/android/support/room/processor/QueryMethodProcessor.kt b/room/compiler/src/main/kotlin/com/android/support/room/processor/QueryMethodProcessor.kt
index fd3849c..1647150 100644
--- a/room/compiler/src/main/kotlin/com/android/support/room/processor/QueryMethodProcessor.kt
+++ b/room/compiler/src/main/kotlin/com/android/support/room/processor/QueryMethodProcessor.kt
@@ -18,6 +18,7 @@
import com.android.support.room.Query
import com.android.support.room.parser.ParsedQuery
+import com.android.support.room.parser.QueryType
import com.android.support.room.parser.SqlParser
import com.android.support.room.vo.QueryMethod
import com.google.auto.common.AnnotationMirrors
@@ -56,6 +57,14 @@
context.checker.notUnbound(returnTypeName, executableElement,
ProcessorErrors.CANNOT_USE_UNBOUND_GENERICS_IN_QUERY_METHODS)
+ if (query.type == QueryType.DELETE) {
+ context.checker.check(
+ returnTypeName == TypeName.VOID || returnTypeName == TypeName.INT,
+ executableElement,
+ ProcessorErrors.DELETION_METHODS_MUST_RETURN_VOID_OR_INT
+ )
+ }
+
val resultAdapter = context.typeAdapterStore
.findQueryResultAdapter(executableType.returnType)
context.checker.check(resultAdapter != null, executableElement,
diff --git a/room/compiler/src/main/kotlin/com/android/support/room/solver/TypeAdapterStore.kt b/room/compiler/src/main/kotlin/com/android/support/room/solver/TypeAdapterStore.kt
index 65eab2e..38510cd 100644
--- a/room/compiler/src/main/kotlin/com/android/support/room/solver/TypeAdapterStore.kt
+++ b/room/compiler/src/main/kotlin/com/android/support/room/solver/TypeAdapterStore.kt
@@ -178,16 +178,19 @@
val converter = findTypeConverter(declared.typeArguments.first(),
context.COMMON_TYPES.STRING)
?: return null
- return CollectionQueryParameterAdapter(converter)
+ val bindAdapter = findColumnTypeAdapter(declared.typeArguments.first()) ?: return null
+ return CollectionQueryParameterAdapter(converter, bindAdapter)
} else if (typeMirror is ArrayType) {
val component = typeMirror.componentType
val converter = findTypeConverter(component, context.COMMON_TYPES.STRING)
?: return null
- return ArrayQueryParameterAdapter(converter)
+ val bindAdapter = findColumnTypeAdapter(component) ?: return null
+ return ArrayQueryParameterAdapter(converter, bindAdapter)
} else {
val converter = findTypeConverter(typeMirror, context.COMMON_TYPES.STRING)
?: return null
- return BasicQueryParameterAdapter(converter)
+ val bindAdapter = findColumnTypeAdapter(typeMirror) ?: return null
+ return BasicQueryParameterAdapter(converter, bindAdapter)
}
}
diff --git a/room/compiler/src/main/kotlin/com/android/support/room/solver/query/parameter/ArrayQueryParameterAdapter.kt b/room/compiler/src/main/kotlin/com/android/support/room/solver/query/parameter/ArrayQueryParameterAdapter.kt
index 94a5760..01aa04b 100644
--- a/room/compiler/src/main/kotlin/com/android/support/room/solver/query/parameter/ArrayQueryParameterAdapter.kt
+++ b/room/compiler/src/main/kotlin/com/android/support/room/solver/query/parameter/ArrayQueryParameterAdapter.kt
@@ -19,14 +19,31 @@
import com.android.support.room.ext.L
import com.android.support.room.ext.T
import com.android.support.room.ext.typeName
+import com.android.support.room.processor.Context
import com.android.support.room.solver.CodeGenScope
+import com.android.support.room.solver.types.ColumnTypeAdapter
import com.android.support.room.solver.types.TypeConverter
import com.squareup.javapoet.TypeName
/**
* Binds ARRAY(T) (e.g. int[]) into String[] args of a query.
*/
-class ArrayQueryParameterAdapter(val converter : TypeConverter) : QueryParameterAdapter(true) {
+class ArrayQueryParameterAdapter(val converter : TypeConverter,
+ val bindAdapter : ColumnTypeAdapter)
+ : QueryParameterAdapter(true) {
+ override fun bindToStmt(inputVarName: String, stmtVarName: String, startIndexVarName: String,
+ scope: CodeGenScope) {
+ scope.builder().apply {
+ val itrVar = scope.getTmpVar("_item")
+ beginControlFlow("for ($T $L : $L)", converter.from.typeName(), itrVar, inputVarName)
+ .apply {
+ bindAdapter.bindToStmt(stmtVarName, startIndexVarName, itrVar, scope)
+ addStatement("$L ++", startIndexVarName)
+ }
+ endControlFlow()
+ }
+ }
+
override fun getArgCount(inputVarName: String, outputVarName : String, scope: CodeGenScope) {
scope.builder()
.addStatement("final $T $L = $L.length", TypeName.INT, outputVarName, inputVarName)
diff --git a/room/compiler/src/main/kotlin/com/android/support/room/solver/query/parameter/BasicQueryParameterAdapter.kt b/room/compiler/src/main/kotlin/com/android/support/room/solver/query/parameter/BasicQueryParameterAdapter.kt
index c6328af..96f1f9c 100644
--- a/room/compiler/src/main/kotlin/com/android/support/room/solver/query/parameter/BasicQueryParameterAdapter.kt
+++ b/room/compiler/src/main/kotlin/com/android/support/room/solver/query/parameter/BasicQueryParameterAdapter.kt
@@ -18,12 +18,23 @@
import com.android.support.room.ext.L
import com.android.support.room.solver.CodeGenScope
+import com.android.support.room.solver.types.ColumnTypeAdapter
import com.android.support.room.solver.types.TypeConverter
+import com.squareup.javapoet.TypeName
/**
* Knows how to convert a query parameter into arguments
*/
-class BasicQueryParameterAdapter(val converter : TypeConverter) : QueryParameterAdapter(false) {
+class BasicQueryParameterAdapter(val converter : TypeConverter,
+ val bindAdapter : ColumnTypeAdapter)
+ : QueryParameterAdapter(false) {
+ override fun bindToStmt(inputVarName: String, stmtVarName: String, startIndexVarName: String,
+ scope: CodeGenScope) {
+ scope.builder().apply {
+ bindAdapter.bindToStmt(stmtVarName, startIndexVarName, inputVarName, scope)
+ }
+ }
+
override fun getArgCount(inputVarName: String, outputVarName : String, scope: CodeGenScope) {
throw UnsupportedOperationException("should not call getArgCount on basic adapters." +
"It is always one.")
diff --git a/room/compiler/src/main/kotlin/com/android/support/room/solver/query/parameter/CollectionQueryParameterAdapter.kt b/room/compiler/src/main/kotlin/com/android/support/room/solver/query/parameter/CollectionQueryParameterAdapter.kt
index b4c2c95..566e0f0 100644
--- a/room/compiler/src/main/kotlin/com/android/support/room/solver/query/parameter/CollectionQueryParameterAdapter.kt
+++ b/room/compiler/src/main/kotlin/com/android/support/room/solver/query/parameter/CollectionQueryParameterAdapter.kt
@@ -20,13 +20,29 @@
import com.android.support.room.ext.T
import com.android.support.room.ext.typeName
import com.android.support.room.solver.CodeGenScope
+import com.android.support.room.solver.types.ColumnTypeAdapter
import com.android.support.room.solver.types.TypeConverter
import com.squareup.javapoet.TypeName
/**
* Binds Collection<T> (e.g. List<T>) into String[] query args.
*/
-class CollectionQueryParameterAdapter(val converter : TypeConverter) : QueryParameterAdapter(true) {
+class CollectionQueryParameterAdapter(val converter : TypeConverter,
+ val bindAdapter : ColumnTypeAdapter)
+ : QueryParameterAdapter(true) {
+ override fun bindToStmt(inputVarName: String, stmtVarName: String, startIndexVarName: String,
+ scope: CodeGenScope) {
+ scope.builder().apply {
+ val itrVar = scope.getTmpVar("_item")
+ beginControlFlow("for ($T $L : $L)", converter.from.typeName(), itrVar, inputVarName)
+ .apply {
+ bindAdapter.bindToStmt(stmtVarName, startIndexVarName, itrVar, scope)
+ addStatement("$L ++", startIndexVarName)
+ }
+ endControlFlow()
+ }
+ }
+
override fun getArgCount(inputVarName: String, outputVarName : String, scope: CodeGenScope) {
scope.builder()
.addStatement("final $T $L = $L.size()", TypeName.INT, outputVarName, inputVarName)
diff --git a/room/compiler/src/main/kotlin/com/android/support/room/solver/query/parameter/QueryParameterAdapter.kt b/room/compiler/src/main/kotlin/com/android/support/room/solver/query/parameter/QueryParameterAdapter.kt
index 90e2248..66b7e48 100644
--- a/room/compiler/src/main/kotlin/com/android/support/room/solver/query/parameter/QueryParameterAdapter.kt
+++ b/room/compiler/src/main/kotlin/com/android/support/room/solver/query/parameter/QueryParameterAdapter.kt
@@ -30,6 +30,12 @@
scope: CodeGenScope)
/**
+ * Must bind the value into the statement at the given index.
+ */
+ abstract fun bindToStmt(inputVarName: String, stmtVarName: String, startIndexVarName: String,
+ scope: CodeGenScope)
+
+ /**
* Should declare and set the given value with the count
*/
abstract fun getArgCount(inputVarName: String, outputVarName : String, scope : CodeGenScope)
diff --git a/room/compiler/src/main/kotlin/com/android/support/room/vo/Field.kt b/room/compiler/src/main/kotlin/com/android/support/room/vo/Field.kt
index 989fec4..02618e4 100644
--- a/room/compiler/src/main/kotlin/com/android/support/room/vo/Field.kt
+++ b/room/compiler/src/main/kotlin/com/android/support/room/vo/Field.kt
@@ -54,7 +54,14 @@
}
val getterNameWithVariations by lazy {
- nameWithVariations.map { "get${it.capitalize()}" }
+ nameWithVariations.map { "get${it.capitalize()}" } +
+ if (typeName == TypeName.BOOLEAN || typeName == TypeName.BOOLEAN.box()) {
+ nameWithVariations.flatMap {
+ listOf("is${it.capitalize()}", "has${it.capitalize()}")
+ }
+ } else {
+ emptyList()
+ }
}
val setterNameWithVariations by lazy {
diff --git a/room/compiler/src/main/kotlin/com/android/support/room/vo/QueryMethod.kt b/room/compiler/src/main/kotlin/com/android/support/room/vo/QueryMethod.kt
index 3f00fc7..aae8112 100644
--- a/room/compiler/src/main/kotlin/com/android/support/room/vo/QueryMethod.kt
+++ b/room/compiler/src/main/kotlin/com/android/support/room/vo/QueryMethod.kt
@@ -16,8 +16,10 @@
package com.android.support.room.vo
+import com.android.support.room.ext.typeName
import com.android.support.room.parser.ParsedQuery
import com.android.support.room.solver.query.result.QueryResultAdapter
+import com.squareup.javapoet.TypeName
import javax.lang.model.element.ExecutableElement
import javax.lang.model.type.TypeMirror
@@ -42,4 +44,8 @@
}
}
}
+
+ val returnsValue by lazy {
+ returnType.typeName() != TypeName.VOID
+ }
}
diff --git a/room/compiler/src/main/kotlin/com/android/support/room/writer/DaoWriter.kt b/room/compiler/src/main/kotlin/com/android/support/room/writer/DaoWriter.kt
index 644bde2..68362bb 100644
--- a/room/compiler/src/main/kotlin/com/android/support/room/writer/DaoWriter.kt
+++ b/room/compiler/src/main/kotlin/com/android/support/room/writer/DaoWriter.kt
@@ -20,6 +20,7 @@
import com.android.support.room.ext.L
import com.android.support.room.ext.N
import com.android.support.room.ext.RoomTypeNames
+import com.android.support.room.ext.SupportDbTypeNames
import com.android.support.room.ext.T
import com.android.support.room.parser.QueryType
import com.android.support.room.solver.CodeGenScope
@@ -37,9 +38,7 @@
import stripNonJava
import javax.lang.model.element.ElementKind
import javax.lang.model.element.ExecutableElement
-import javax.lang.model.element.Modifier.FINAL
-import javax.lang.model.element.Modifier.PRIVATE
-import javax.lang.model.element.Modifier.PUBLIC
+import javax.lang.model.element.Modifier.*
/**
* Creates the implementation for a class annotated with Dao.
@@ -55,8 +54,22 @@
val builder = TypeSpec.classBuilder(dao.implTypeName)
val scope = CodeGenScope()
+ /**
+ * if delete methods wants to return modified rows, we need prepared query.
+ * in that case, if args are dynamic, we cannot re-use the query, if not, we should re-use
+ * it. this requires more work but creates good performance.
+ */
+ val groupedDeletions = dao.queryMethods
+ .filter { it.query.type == QueryType.DELETE }
+ .groupBy { it.parameters.any { it.queryParamAdapter?.isMultiple ?: true } }
+ // delete queries that can be prepared ahead of time
+ val preparedDeleteQueries = groupedDeletions[false] ?: emptyList()
+ // delete queries that must be rebuild every single time
+ val oneOffDeleteQueries = groupedDeletions[true] ?: emptyList()
val shortcutMethods = groupAndCreateInsertionMethods(scope) +
- groupAndCreateDeletionMethods(scope)
+ groupAndCreateDeletionMethods(scope) +
+ createPreparedDeleteQueries(preparedDeleteQueries, scope)
+
builder.apply {
addModifiers(PUBLIC)
if (dao.element.kind == ElementKind.INTERFACE) {
@@ -76,15 +89,55 @@
}
}
- dao.queryMethods.filter { it.query.queryType == QueryType.SELECT }.forEach { method ->
- builder.addMethod(createSelectMethod(method))
+ dao.queryMethods.filter { it.query.type == QueryType.SELECT }.forEach { method ->
+ addMethod(createSelectMethod(method))
+ }
+ oneOffDeleteQueries.forEach {
+ addMethod(createDeleteQueryMethod(it))
}
}
return builder.build()
}
+ private fun createPreparedDeleteQueries(preparedDeleteQueries: List<QueryMethod>,
+ scope: CodeGenScope): List<PreparedStmtQuery> {
+ return preparedDeleteQueries.map { method ->
+ val fieldName = scope.getTmpVar("_preparedStmtOf${method.name.capitalize()}")
+ val fieldSpec = FieldSpec.builder(RoomTypeNames.SHARED_SQLITE_STMT, fieldName,
+ PRIVATE, FINAL).build()
+ val queryWriter = QueryWriter(method)
+ val fieldImpl = PreparedStatementWriter(queryWriter).createAnonymous(dbField)
+ val methodBody = createPreparedDeleteQueryMethodBody(method, fieldSpec, queryWriter)
+ PreparedStmtQuery(fieldSpec, fieldImpl, listOf(methodBody))
+ }
+ }
+
+ private fun createPreparedDeleteQueryMethodBody(method: QueryMethod,
+ preparedStmtField : FieldSpec,
+ queryWriter: QueryWriter): MethodSpec {
+ val scope = CodeGenScope()
+ val methodBuilder = overrideWithoutAnnotations(method.element).apply {
+ val stmtName = scope.getTmpVar("_stmt")
+ addStatement("final $T $L = $N.acquire()",
+ SupportDbTypeNames.SQLITE_STMT, stmtName, preparedStmtField)
+ beginControlFlow("try").apply {
+ val bindScope = scope.fork()
+ queryWriter.bindArgs(stmtName, emptyList(), bindScope)
+ addCode(bindScope.builder().build())
+ addStatement("$L$L.executeUpdateDelete()",
+ if (method.returnsValue) "return " else "",
+ stmtName)
+ }
+ nextControlFlow("finally").apply {
+ addStatement("$N.release($L)", preparedStmtField, stmtName)
+ }
+ endControlFlow()
+ }
+ return methodBuilder.build()
+ }
+
private fun createConstructor(dbParam: ParameterSpec,
- shortcutMethods: List<GroupedShortcut>): MethodSpec {
+ shortcutMethods: List<PreparedStmtQuery>): MethodSpec {
return MethodSpec.constructorBuilder().apply {
addParameter(dbParam)
addModifiers(PUBLIC)
@@ -103,11 +156,17 @@
}.build()
}
+ private fun createDeleteQueryMethod(method : QueryMethod) : MethodSpec {
+ return overrideWithoutAnnotations(method.element).apply {
+ addCode(createDeleteQueryMethodBody(method))
+ }.build()
+ }
+
/**
* Groups all insertion methods based on the insert statement they will use then creates all
* field specs, EntityInsertionAdapterWriter and actual insert methods.
*/
- private fun groupAndCreateInsertionMethods(scope : CodeGenScope): List<GroupedShortcut> {
+ private fun groupAndCreateInsertionMethods(scope : CodeGenScope): List<PreparedStmtQuery> {
return dao.insertionMethods
.groupBy {
Pair(it.entity?.typeName, it.onConflictText)
@@ -134,7 +193,7 @@
addCode(createInsertionMethodBody(method, fieldSpec))
}.build()
}
- GroupedShortcut(fieldSpec, implSpec, insertionMethodImpls)
+ PreparedStmtQuery(fieldSpec, implSpec, insertionMethodImpls)
}
}
@@ -168,7 +227,7 @@
* Groups all deletion methods based on the delete statement they will use then creates all
* field specs, EntityDeletionAdapterWriter and actual deletion methods.
*/
- private fun groupAndCreateDeletionMethods(scope : CodeGenScope): List<GroupedShortcut> {
+ private fun groupAndCreateDeletionMethods(scope : CodeGenScope): List<PreparedStmtQuery> {
return dao.deletionMethods
.groupBy {
it.entity?.typeName
@@ -194,7 +253,7 @@
addCode(createDeletionMethodBody(method, fieldSpec))
}.build()
}
- GroupedShortcut(fieldSpec, implSpec, deletionMethodImpls)
+ PreparedStmtQuery(fieldSpec, implSpec, deletionMethodImpls)
}
}
@@ -232,12 +291,32 @@
}.build()
}
+ /**
+ * @Query with delete action
+ */
+ private fun createDeleteQueryMethodBody(method: QueryMethod): CodeBlock {
+ val queryWriter = QueryWriter(method)
+ val scope = CodeGenScope()
+ val sqlVar = scope.getTmpVar("_sql")
+ val stmtVar = scope.getTmpVar("_stmt")
+ queryWriter.prepareQuery(sqlVar, scope)
+ scope.builder().apply {
+ addStatement("$T $L = $N.compileStatement($L)",
+ SupportDbTypeNames.SQLITE_STMT, stmtVar, dbField, sqlVar)
+ queryWriter.bindArgs(stmtVar, emptyList(), scope)
+ addStatement("$L$L.executeUpdateDelete()",
+ if (method.returnsValue) "return " else "",
+ stmtVar)
+ }
+ return scope.builder().build()
+ }
+
private fun createQueryMethodBody(method: QueryMethod): CodeBlock {
val queryWriter = QueryWriter(method)
val scope = CodeGenScope()
val sqlVar = scope.getTmpVar("_sql")
val argsVar = scope.getTmpVar("_args")
- queryWriter.prepareReadQuery(sqlVar, argsVar, scope)
+ queryWriter.prepareReadAndBind(sqlVar, argsVar, scope)
scope.builder().apply {
val cursorVar = scope.getTmpVar("_cursor")
val outVar = scope.getTmpVar("_result")
@@ -272,7 +351,7 @@
}
}
- data class GroupedShortcut(val field: FieldSpec?,
- val fieldImpl: TypeSpec?,
- val methodImpls: List<MethodSpec>)
+ data class PreparedStmtQuery(val field: FieldSpec?,
+ val fieldImpl: TypeSpec?,
+ val methodImpls: List<MethodSpec>)
}
diff --git a/room/compiler/src/main/kotlin/com/android/support/room/writer/EntityInsertionAdapterWriter.kt b/room/compiler/src/main/kotlin/com/android/support/room/writer/EntityInsertionAdapterWriter.kt
index b786112..1426e7c 100644
--- a/room/compiler/src/main/kotlin/com/android/support/room/writer/EntityInsertionAdapterWriter.kt
+++ b/room/compiler/src/main/kotlin/com/android/support/room/writer/EntityInsertionAdapterWriter.kt
@@ -17,7 +17,6 @@
package com.android.support.room.writer
import com.android.support.room.ext.L
-import com.android.support.room.ext.N
import com.android.support.room.ext.RoomTypeNames
import com.android.support.room.ext.S
import com.android.support.room.ext.SupportDbTypeNames
@@ -39,7 +38,7 @@
superclass(
ParameterizedTypeName.get(RoomTypeNames.INSERTION_ADAPTER, entity.typeName)
)
- addMethod(MethodSpec.methodBuilder("createInsertQuery").apply {
+ addMethod(MethodSpec.methodBuilder("createQuery").apply {
addAnnotation(Override::class.java)
returns(ClassName.get("java.lang", "String"))
addModifiers(PUBLIC)
diff --git a/room/compiler/src/main/kotlin/com/android/support/room/writer/PreparedStatementWriter.kt b/room/compiler/src/main/kotlin/com/android/support/room/writer/PreparedStatementWriter.kt
new file mode 100644
index 0000000..0a6d11b
--- /dev/null
+++ b/room/compiler/src/main/kotlin/com/android/support/room/writer/PreparedStatementWriter.kt
@@ -0,0 +1,49 @@
+/*
+ * Copyright (C) 2016 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.android.support.room.writer
+
+import com.android.support.room.ext.L
+import com.android.support.room.ext.N
+import com.android.support.room.ext.RoomTypeNames
+import com.android.support.room.solver.CodeGenScope
+import com.squareup.javapoet.ClassName
+import com.squareup.javapoet.FieldSpec
+import com.squareup.javapoet.MethodSpec
+import com.squareup.javapoet.TypeSpec
+import javax.lang.model.element.Modifier
+
+/**
+ * Creates anonymous classes for RoomTypeNames#SHARED_SQLITE_STMT.
+ */
+class PreparedStatementWriter(val queryWriter: QueryWriter) {
+ fun createAnonymous(dbParam : FieldSpec): TypeSpec {
+ val scope = CodeGenScope()
+ @Suppress("RemoveSingleExpressionStringTemplate")
+ return TypeSpec.anonymousClassBuilder("$N", dbParam).apply {
+ superclass(RoomTypeNames.SHARED_SQLITE_STMT)
+ addMethod(MethodSpec.methodBuilder("createQuery").apply {
+ addAnnotation(Override::class.java)
+ returns(ClassName.get("java.lang", "String"))
+ addModifiers(Modifier.PUBLIC)
+ val queryName = scope.getTmpVar("_query")
+ val queryGenScope = scope.fork()
+ queryWriter.prepareQuery(queryName, queryGenScope)
+ addCode(queryGenScope.builder().build())
+ addStatement("return $L", queryName)
+ }.build())
+ }.build()
+ }
+}
diff --git a/room/compiler/src/main/kotlin/com/android/support/room/writer/QueryWriter.kt b/room/compiler/src/main/kotlin/com/android/support/room/writer/QueryWriter.kt
index fff4637..6b54658 100644
--- a/room/compiler/src/main/kotlin/com/android/support/room/writer/QueryWriter.kt
+++ b/room/compiler/src/main/kotlin/com/android/support/room/writer/QueryWriter.kt
@@ -22,6 +22,7 @@
import com.android.support.room.ext.T
import com.android.support.room.ext.arrayTypeName
import com.android.support.room.ext.typeName
+import com.android.support.room.parser.QueryType
import com.android.support.room.parser.SectionType.BIND_VAR
import com.android.support.room.parser.SectionType.NEWLINE
import com.android.support.room.parser.SectionType.TEXT
@@ -35,17 +36,16 @@
* Writes the SQL query and arguments for a QueryMethod.
*/
class QueryWriter(val queryMethod: QueryMethod) {
- fun prepareReadQuery(outSqlQueryName: String, outArgsName: String, scope: CodeGenScope) {
- scope.builder().apply {
- // mapping from parameters to the variables created for their sizes
- // it is a list of pairs instead of a map because same parameter might be bound to
- // multiple bind args
- val listSizeVars = createSqlQueryAndArgs(outSqlQueryName, outArgsName, scope)
- bindArgs(outArgsName, listSizeVars, scope)
- }
+ fun prepareReadAndBind(outSqlQueryName: String, outArgsName: String, scope: CodeGenScope) {
+ val listSizeVars = createSqlQueryAndArgs(outSqlQueryName, outArgsName, scope)
+ bindArgs(outArgsName, listSizeVars, scope)
}
- private fun createSqlQueryAndArgs(outSqlQueryName: String, outArgsName: String,
+ fun prepareQuery(outSqlQueryName: String, scope: CodeGenScope) {
+ createSqlQueryAndArgs(outSqlQueryName, null, scope)
+ }
+
+ private fun createSqlQueryAndArgs(outSqlQueryName: String, outArgsName: String?,
scope: CodeGenScope): List<Pair<QueryParameter, String>> {
val listSizeVars = arrayListOf<Pair<QueryParameter, String>>()
val varargParams = queryMethod.parameters
@@ -87,30 +87,34 @@
addStatement("$T $L = $L.toString()", String::class.typeName(),
outSqlQueryName, stringBuilderVar)
- val argCount = scope.getTmpVar("_argCount")
-
- addStatement("final $T $L = $L$L", TypeName.INT, argCount, knownQueryArgsCount,
- listSizeVars.joinToString("") { " + ${it.second}" })
- addStatement("$T $L = new String[$L]",
- String::class.arrayTypeName(), outArgsName, argCount)
+ if (outArgsName != null) {
+ val argCount = scope.getTmpVar("_argCount")
+ addStatement("final $T $L = $L$L", TypeName.INT, argCount, knownQueryArgsCount,
+ listSizeVars.joinToString("") { " + ${it.second}" })
+ addStatement("$T $L = new String[$L]",
+ String::class.arrayTypeName(), outArgsName, argCount)
+ }
} else {
addStatement("$T $L = $S", String::class.typeName(),
outSqlQueryName, queryMethod.query.queryWithReplacedBindParams)
- addStatement("$T $L = new String[$L]",
- String::class.arrayTypeName(), outArgsName, knownQueryArgsCount)
+ if (outArgsName != null) {
+ addStatement("$T $L = new String[$L]",
+ String::class.arrayTypeName(), outArgsName, knownQueryArgsCount)
+ }
}
}
return listSizeVars
}
- private fun bindArgs(outArgsName: String, listSizeVars : List<Pair<QueryParameter, String>>
- ,scope: CodeGenScope) {
+ fun bindArgs(outArgsName: String, listSizeVars : List<Pair<QueryParameter, String>>,
+ scope: CodeGenScope) {
if (queryMethod.parameters.isEmpty()) {
return
}
scope.builder().apply {
val argIndex = scope.getTmpVar("_argIndex")
- addStatement("$T $L = 0", TypeName.INT, argIndex)
+ val startIndex = if (queryMethod.query.type == QueryType.SELECT) 0 else 1
+ addStatement("$T $L = $L", TypeName.INT, argIndex, startIndex)
// # of bindings with 1 placeholder
var constInputs = 0
// variable names for size of the bindings that have multiple args
@@ -118,19 +122,25 @@
queryMethod.sectionToParamMapping.forEach { pair ->
// reset the argIndex to the correct start index
if (constInputs > 0 || varInputs.isNotEmpty()) {
- addStatement("$L = $L$L$L", argIndex,
+ addStatement("$L = $L$L$L$L", argIndex,
+ if (startIndex > 0) "$startIndex + " else "",
if (constInputs > 0) constInputs else "",
if (constInputs > 0 && varInputs.isNotEmpty()) " + " else "",
varInputs.joinToString(" + "))
}
val param = pair.second
param?.let {
- param.queryParamAdapter?.convert(param.name, outArgsName, argIndex, scope)
+ if (queryMethod.query.type == QueryType.SELECT) {
+ param.queryParamAdapter?.convert(param.name, outArgsName, argIndex, scope)
+ } else {
+ param.queryParamAdapter?.bindToStmt(param.name, outArgsName, argIndex,
+ scope)
+ }
}
// add these to the list so that we can use them to calculate the next count.
val sizeVar = listSizeVars.firstOrNull { it.first == param }
if (sizeVar == null) {
- constInputs++
+ constInputs ++
} else {
varInputs.add(sizeVar.second)
}
diff --git a/room/compiler/src/test/data/daoWriter/input/DeletionDao.java b/room/compiler/src/test/data/daoWriter/input/DeletionDao.java
index 0dc5f16..64e91b5 100644
--- a/room/compiler/src/test/data/daoWriter/input/DeletionDao.java
+++ b/room/compiler/src/test/data/daoWriter/input/DeletionDao.java
@@ -36,4 +36,10 @@
@Delete
int multiPKey(MultiPKeyEntity entity);
+
+ @Query("DELETE FROM user where uid = ?")
+ int deleteByUid(int uid);
+
+ @Query("DELETE FROM user where uid IN(?)")
+ int deleteByUidList(int... uid);
}
diff --git a/room/compiler/src/test/data/daoWriter/output/DeletionDao.java b/room/compiler/src/test/data/daoWriter/output/DeletionDao.java
index 0c2aac5..1fd44fb 100644
--- a/room/compiler/src/test/data/daoWriter/output/DeletionDao.java
+++ b/room/compiler/src/test/data/daoWriter/output/DeletionDao.java
@@ -3,8 +3,11 @@
import com.android.support.db.SupportSQLiteStatement;
import com.android.support.room.EntityDeletionOrUpdateAdapter;
import com.android.support.room.RoomDatabase;
+import com.android.support.room.SharedSQLiteStatement;
+import com.android.support.room.util.StringUtil;
import java.lang.Override;
import java.lang.String;
+import java.lang.StringBuilder;
import java.util.List;
public class DeletionDao_Impl implements DeletionDao {
@@ -14,6 +17,8 @@
private final EntityDeletionOrUpdateAdapter __deletionAdapterOfMultiPKeyEntity;
+ private final SharedSQLiteStatement _preparedStmtOfDeleteByUid;
+
public DeletionDao_Impl(RoomDatabase __db) {
this.__db = __db;
this.__deletionAdapterOfUser = new EntityDeletionOrUpdateAdapter<User>(__db) {
@@ -47,6 +52,13 @@
}
}
};
+ this._preparedStmtOfDeleteByUid = new SharedSQLiteStatement(__db) {
+ @Override
+ public String createQuery() {
+ String _query = "DELETE FROM user where uid = ?";
+ return _query;
+ }
+ };
}
@Override
@@ -135,4 +147,33 @@
__db.endTransaction();
}
}
-}
\ No newline at end of file
+
+ @Override
+ public int deleteByUid(int uid) {
+ final SupportSQLiteStatement _stmt = _preparedStmtOfDeleteByUid.acquire();
+ try {
+ int _argIndex = 1;
+ _stmt.bindLong(_argIndex, uid);
+ return _stmt.executeUpdateDelete();
+ } finally {
+ _preparedStmtOfDeleteByUid.release(_stmt);
+ }
+ }
+
+ @Override
+ public int deleteByUidList(int... uid) {
+ StringBuilder _stringBuilder = StringUtil.newStringBuilder();
+ _stringBuilder.append("DELETE FROM user where uid IN(");
+ final int _inputSize = uid.length;
+ StringUtil.appendPlaceholders(_stringBuilder, _inputSize);
+ _stringBuilder.append(")");
+ String _sql = _stringBuilder.toString();
+ SupportSQLiteStatement _stmt = __db.compileStatement(_sql);
+ int _argIndex = 1;
+ for (int _item : uid) {
+ _stmt.bindLong(_argIndex, _item);
+ _argIndex ++;
+ }
+ return _stmt.executeUpdateDelete();
+ }
+}
diff --git a/room/compiler/src/test/data/daoWriter/output/WriterDao.java b/room/compiler/src/test/data/daoWriter/output/WriterDao.java
index 34f2729..bebfe23 100644
--- a/room/compiler/src/test/data/daoWriter/output/WriterDao.java
+++ b/room/compiler/src/test/data/daoWriter/output/WriterDao.java
@@ -35,7 +35,7 @@
this.__db = __db;
this.__insertionAdapterOfUser = new EntityInsertionAdapter<User>(__db) {
@Override
- public String createInsertQuery() {
+ public String createQuery() {
return "INSERT OR ABORT INTO `User`(`uid`,`name`,`lastName`,`ageColumn`) VALUES"
+ " (?,?,?,?)";
}
@@ -58,7 +58,7 @@
};
this.__insertionAdapterOfUser_1 = new EntityInsertionAdapter<User>(__db) {
@Override
- public String createInsertQuery() {
+ public String createQuery() {
return "INSERT OR REPLACE INTO `User`(`uid`,`name`,`lastName`,`ageColumn`) VALUES"
+ " (?,?,?,?)";
}
diff --git a/room/compiler/src/test/kotlin/com/android/support/room/parser/SqlParserTest.kt b/room/compiler/src/test/kotlin/com/android/support/room/parser/SqlParserTest.kt
index 49b2a14..1e29b93 100644
--- a/room/compiler/src/test/kotlin/com/android/support/room/parser/SqlParserTest.kt
+++ b/room/compiler/src/test/kotlin/com/android/support/room/parser/SqlParserTest.kt
@@ -38,8 +38,15 @@
@Test
fun deleteQuery() {
- assertErrors("DELETE FROM users where id > 3",
- ParserErrors.invalidQueryType(QueryType.DELETE))
+ val parsed = SqlParser.parse("DELETE FROM users where id > 3")
+ assertThat(parsed.errors, `is`(emptyList()))
+ assertThat(parsed.type, `is`(QueryType.DELETE))
+ }
+
+ @Test
+ fun badDeleteQuery() {
+ assertErrors("delete from user where mAge >= :min && mAge <= :max",
+ "no viable alternative at input 'delete from user where mAge >= :min &&'")
}
@Test
diff --git a/room/compiler/src/test/kotlin/com/android/support/room/processor/QueryMethodProcessorTest.kt b/room/compiler/src/test/kotlin/com/android/support/room/processor/QueryMethodProcessorTest.kt
index c92dd24..6dd0d6b 100644
--- a/room/compiler/src/test/kotlin/com/android/support/room/processor/QueryMethodProcessorTest.kt
+++ b/room/compiler/src/test/kotlin/com/android/support/room/processor/QueryMethodProcessorTest.kt
@@ -18,6 +18,7 @@
import com.android.support.room.Dao
import com.android.support.room.Query
+import com.android.support.room.ext.hasAnnotation
import com.android.support.room.ext.typeName
import com.android.support.room.testing.TestInvocation
import com.android.support.room.testing.TestProcessor
@@ -287,6 +288,31 @@
}.compilesWithoutError()
}
+ @Test
+ fun testReadDeleteWithBadReturnType() {
+ singleQueryMethod(
+ """
+ @Query("DELETE FROM users where id = ?")
+ abstract public float foo(int id);
+ """) { parsedQuery, invocation ->
+ }.failsToCompile().withErrorContaining(
+ ProcessorErrors.DELETION_METHODS_MUST_RETURN_VOID_OR_INT
+ )
+ }
+
+ @Test
+ fun testSimpleDelete() {
+ singleQueryMethod(
+ """
+ @Query("DELETE FROM users where id = ?")
+ abstract public int foo(int id);
+ """) { parsedQuery, invocation ->
+ assertThat(parsedQuery.name, `is`("foo"))
+ assertThat(parsedQuery.parameters.size, `is`(1))
+ assertThat(parsedQuery.returnType.typeName(), `is`(TypeName.INT))
+ }.compilesWithoutError()
+ }
+
fun singleQueryMethod(vararg input: String,
handler: (QueryMethod, TestInvocation) -> Unit):
CompileTester {
@@ -304,8 +330,7 @@
invocation.processingEnv.elementUtils
.getAllMembers(MoreElements.asType(it))
.filter {
- MoreElements.isAnnotationPresent(it,
- Query::class.java)
+ it.hasAnnotation(Query::class)
}
)
}.filter { it.second.isNotEmpty() }.first()
diff --git a/room/compiler/src/test/kotlin/com/android/support/room/solver/query/QueryWriterTest.kt b/room/compiler/src/test/kotlin/com/android/support/room/solver/query/QueryWriterTest.kt
index fe6072c..4a2f4a2 100644
--- a/room/compiler/src/test/kotlin/com/android/support/room/solver/query/QueryWriterTest.kt
+++ b/room/compiler/src/test/kotlin/com/android/support/room/solver/query/QueryWriterTest.kt
@@ -55,7 +55,7 @@
abstract java.util.List<Integer> selectAllIds();
""") { writer ->
val scope = CodeGenScope()
- writer.prepareReadQuery("_sql", "_args", scope)
+ writer.prepareReadAndBind("_sql", "_args", scope)
assertThat(scope.generate().trim(), `is`(
"""
java.lang.String _sql = "SELECT id FROM users";
@@ -71,7 +71,7 @@
abstract java.util.List<Integer> selectAllIds(String name);
""") { writer ->
val scope = CodeGenScope()
- writer.prepareReadQuery("_sql", "_args", scope)
+ writer.prepareReadAndBind("_sql", "_args", scope)
assertThat(scope.generate().trim(), `is`(
"""
java.lang.String _sql = "SELECT id FROM users WHERE name LIKE ?";
@@ -89,7 +89,7 @@
abstract java.util.List<Integer> selectAllIds(int id1, int id2);
""") { writer ->
val scope = CodeGenScope()
- writer.prepareReadQuery("_sql", "_args", scope)
+ writer.prepareReadAndBind("_sql", "_args", scope)
assertThat(scope.generate().trim(), `is`(
"""
java.lang.String _sql = "SELECT id FROM users WHERE id IN(?,?)";
@@ -109,7 +109,7 @@
abstract java.util.List<Integer> selectAllIds(long time, int... ids);
""") { writer ->
val scope = CodeGenScope()
- writer.prepareReadQuery("_sql", "_args", scope)
+ writer.prepareReadAndBind("_sql", "_args", scope)
assertThat(scope.generate().trim(), `is`(
"""
java.lang.StringBuilder _stringBuilder = $STRING_UTIL.newStringBuilder();
@@ -158,7 +158,7 @@
abstract List<Integer> selectAllIds(long time, List<Integer> ids);
""") { writer ->
val scope = CodeGenScope()
- writer.prepareReadQuery("_sql", "_args", scope)
+ writer.prepareReadAndBind("_sql", "_args", scope)
assertThat(scope.generate().trim(), `is`(collectionOut))
}.compilesWithoutError()
}
@@ -170,7 +170,7 @@
abstract List<Integer> selectAllIds(long time, Set<Integer> ids);
""") { writer ->
val scope = CodeGenScope()
- writer.prepareReadQuery("_sql", "_args", scope)
+ writer.prepareReadAndBind("_sql", "_args", scope)
assertThat(scope.generate().trim(), `is`(collectionOut))
}.compilesWithoutError()
}
@@ -182,7 +182,7 @@
abstract List<Integer> selectAllIds(int age);
""") { writer ->
val scope = CodeGenScope()
- writer.prepareReadQuery("_sql", "_args", scope)
+ writer.prepareReadAndBind("_sql", "_args", scope)
assertThat(scope.generate().trim(), `is`("""
java.lang.String _sql = "SELECT id FROM users WHERE age > ? OR bage > ?";
java.lang.String[] _args = new String[2];
@@ -201,7 +201,7 @@
abstract List<Integer> selectAllIds(int age, int... ages);
""") { writer ->
val scope = CodeGenScope()
- writer.prepareReadQuery("_sql", "_args", scope)
+ writer.prepareReadAndBind("_sql", "_args", scope)
assertThat(scope.generate().trim(), `is`("""
java.lang.StringBuilder _stringBuilder = $STRING_UTIL.newStringBuilder();
_stringBuilder.append("SELECT id FROM users WHERE age > ");
@@ -235,7 +235,7 @@
abstract List<Integer> selectAllIds(int age, int... ages);
""") { writer ->
val scope = CodeGenScope()
- writer.prepareReadQuery("_sql", "_args", scope)
+ writer.prepareReadAndBind("_sql", "_args", scope)
assertThat(scope.generate().trim(), `is`("""
java.lang.StringBuilder _stringBuilder = $STRING_UTIL.newStringBuilder();
_stringBuilder.append("SELECT id FROM users WHERE age IN (");
diff --git a/room/integration-tests/testapp/src/androidTest/java/com/android/support/room/integration/testapp/dao/UserDao.java b/room/integration-tests/testapp/src/androidTest/java/com/android/support/room/integration/testapp/dao/UserDao.java
index 1295d8f..1298735 100644
--- a/room/integration-tests/testapp/src/androidTest/java/com/android/support/room/integration/testapp/dao/UserDao.java
+++ b/room/integration-tests/testapp/src/androidTest/java/com/android/support/room/integration/testapp/dao/UserDao.java
@@ -49,4 +49,16 @@
@Insert
void insertAll(User[] users);
+
+ @Query("select * from user where mAdmin = ?")
+ List<User> findByAdmin(boolean isAdmin);
+
+ @Query("delete from user where mAge > ?")
+ int deleteAgeGreaterThan(int age);
+
+ @Query("delete from user where mId IN(?)")
+ int deleteByUids(int... uids);
+
+ @Query("delete from user where mAge >= :min AND mAge <= :max")
+ int deleteByAgeRange(int min, int max);
}
diff --git a/room/integration-tests/testapp/src/androidTest/java/com/android/support/room/integration/testapp/test/SimpleEntityReadWriteTest.java b/room/integration-tests/testapp/src/androidTest/java/com/android/support/room/integration/testapp/test/SimpleEntityReadWriteTest.java
index 74641bd..dbcfa13 100644
--- a/room/integration-tests/testapp/src/androidTest/java/com/android/support/room/integration/testapp/test/SimpleEntityReadWriteTest.java
+++ b/room/integration-tests/testapp/src/androidTest/java/com/android/support/room/integration/testapp/test/SimpleEntityReadWriteTest.java
@@ -39,12 +39,15 @@
import org.junit.Test;
import org.junit.runner.RunWith;
+import java.util.Arrays;
import java.util.List;
+@SuppressWarnings("ArraysAsListWithZeroOrOneArgument")
@SmallTest
@RunWith(AndroidJUnit4.class)
public class SimpleEntityReadWriteTest {
private UserDao mUserDao;
+
@Before
public void createDb() {
Context context = InstrumentationRegistry.getTargetContext();
@@ -107,4 +110,54 @@
assertThat(deleteCount, is(2));
assertThat(mUserDao.loadByIds(3, 5, 7, 9), is(new User[]{users[1], users[2]}));
}
+
+ @Test
+ public void findByBoolean() {
+ User user1 = TestUtil.createUser(3);
+ user1.setAdmin(true);
+ User user2 = TestUtil.createUser(5);
+ user2.setAdmin(false);
+ mUserDao.insert(user1);
+ mUserDao.insert(user2);
+ assertThat(mUserDao.findByAdmin(true), is(Arrays.asList(user1)));
+ assertThat(mUserDao.findByAdmin(false), is(Arrays.asList(user2)));
+ }
+
+ @Test
+ public void deleteByAge() {
+ User user1 = TestUtil.createUser(3);
+ user1.setAge(30);
+ User user2 = TestUtil.createUser(5);
+ user2.setAge(45);
+ mUserDao.insert(user1);
+ mUserDao.insert(user2);
+ assertThat(mUserDao.deleteAgeGreaterThan(60), is(0));
+ assertThat(mUserDao.deleteAgeGreaterThan(45), is(0));
+ assertThat(mUserDao.deleteAgeGreaterThan(35), is(1));
+ assertThat(mUserDao.loadByIds(3, 5), is(new User[]{user1}));
+ }
+
+ @Test
+ public void deleteByAgeRange() {
+ User user1 = TestUtil.createUser(3);
+ user1.setAge(30);
+ User user2 = TestUtil.createUser(5);
+ user2.setAge(45);
+ mUserDao.insert(user1);
+ mUserDao.insert(user2);
+ assertThat(mUserDao.deleteByAgeRange(35, 40), is(0));
+ assertThat(mUserDao.deleteByAgeRange(25, 30), is(1));
+ assertThat(mUserDao.loadByIds(3, 5), is(new User[]{user2}));
+ }
+
+ @Test
+ public void deleteByUIds() {
+ User[] users = TestUtil.createUsersArray(3, 5, 7, 9, 11);
+ mUserDao.insertAll(users);
+ assertThat(mUserDao.deleteByUids(2, 4, 6), is(0));
+ assertThat(mUserDao.deleteByUids(3, 11), is(2));
+ assertThat(mUserDao.loadByIds(3, 5, 7, 9, 11), is(new User[]{
+ users[1], users[2], users[3]
+ }));
+ }
}
diff --git a/room/integration-tests/testapp/src/androidTest/java/com/android/support/room/integration/testapp/vo/User.java b/room/integration-tests/testapp/src/androidTest/java/com/android/support/room/integration/testapp/vo/User.java
index 2ccedb0..aa75486 100644
--- a/room/integration-tests/testapp/src/androidTest/java/com/android/support/room/integration/testapp/vo/User.java
+++ b/room/integration-tests/testapp/src/androidTest/java/com/android/support/room/integration/testapp/vo/User.java
@@ -26,6 +26,7 @@
private String mName;
private String mLastName;
private int mAge;
+ private boolean mAdmin;
public int getId() {
return mId;
@@ -59,6 +60,14 @@
this.mAge = age;
}
+ public boolean isAdmin() {
+ return mAdmin;
+ }
+
+ public void setAdmin(boolean admin) {
+ mAdmin = admin;
+ }
+
@Override
public boolean equals(Object o) {
if (this == o) return true;
@@ -68,6 +77,7 @@
if (mId != user.mId) return false;
if (mAge != user.mAge) return false;
+ if (mAdmin != user.mAdmin) return false;
if (mName != null ? !mName.equals(user.mName) : user.mName != null) return false;
return mLastName != null ? mLastName.equals(user.mLastName) : user.mLastName == null;
}
@@ -78,6 +88,7 @@
result = 31 * result + (mName != null ? mName.hashCode() : 0);
result = 31 * result + (mLastName != null ? mLastName.hashCode() : 0);
result = 31 * result + mAge;
+ result = 31 * result + (mAdmin ? 1 : 0);
return result;
}
}
diff --git a/room/runtime/src/main/java/com/android/support/room/EntityDeletionOrUpdateAdapter.java b/room/runtime/src/main/java/com/android/support/room/EntityDeletionOrUpdateAdapter.java
index 8842f9c..824ab33 100644
--- a/room/runtime/src/main/java/com/android/support/room/EntityDeletionOrUpdateAdapter.java
+++ b/room/runtime/src/main/java/com/android/support/room/EntityDeletionOrUpdateAdapter.java
@@ -21,7 +21,6 @@
import com.android.support.db.SupportSQLiteStatement;
import java.util.Collection;
-import java.util.concurrent.atomic.AtomicBoolean;
/**
* Implementations of this class knows how to delete or update a particular entity.
@@ -33,11 +32,7 @@
*/
@RestrictTo(RestrictTo.Scope.LIBRARY_GROUP)
@SuppressWarnings({"WeakerAccess", "unused"})
-public abstract class EntityDeletionOrUpdateAdapter<T> {
- private final AtomicBoolean mStmtLock = new AtomicBoolean(false);
- private final RoomDatabase mDatabase;
- private volatile SupportSQLiteStatement mStmt;
-
+public abstract class EntityDeletionOrUpdateAdapter<T> extends SharedSQLiteStatement {
/**
* Creates a DeletionOrUpdateAdapter that can delete or update the entity type T on the given
* database.
@@ -45,7 +40,7 @@
* @param database The database to delete / update the item in.
*/
public EntityDeletionOrUpdateAdapter(RoomDatabase database) {
- mDatabase = database;
+ super(database);
}
/**
@@ -64,25 +59,6 @@
*/
protected abstract void bind(SupportSQLiteStatement statement, T entity);
- private SupportSQLiteStatement createNewStatement() {
- String query = createQuery();
- return mDatabase.compileStatement(query);
- }
-
- private SupportSQLiteStatement getStmt(boolean canUseCached) {
- final SupportSQLiteStatement stmt;
- if (canUseCached) {
- if (mStmt == null) {
- mStmt = createNewStatement();
- }
- stmt = mStmt;
- } else {
- // it is in use, create a one off statement
- stmt = createNewStatement();
- }
- return stmt;
- }
-
/**
* Deletes or updates the given entities in the database and returns the affected row count.
*
@@ -90,15 +66,12 @@
* @return The number of affected rows
*/
public final int handle(T entity) {
- boolean useCached = !mStmtLock.getAndSet(true);
+ final SupportSQLiteStatement stmt = acquire();
try {
- final SupportSQLiteStatement stmt = getStmt(useCached);
bind(stmt, entity);
return stmt.executeUpdateDelete();
} finally {
- if (useCached) {
- mStmtLock.set(false);
- }
+ release(stmt);
}
}
@@ -109,19 +82,16 @@
* @return The number of affected rows
*/
public final int handleMultiple(Collection<T> entities) {
- boolean useCached = !mStmtLock.getAndSet(true);
+ final SupportSQLiteStatement stmt = acquire();
try {
int total = 0;
- final SupportSQLiteStatement stmt = getStmt(useCached);
for (T entity : entities) {
bind(stmt, entity);
total += stmt.executeUpdateDelete();
}
return total;
} finally {
- if (useCached) {
- mStmtLock.set(false);
- }
+ release(stmt);
}
}
@@ -132,19 +102,16 @@
* @return The number of affected rows
*/
public final int handleMultiple(T[] entities) {
- boolean useCached = !mStmtLock.getAndSet(true);
+ final SupportSQLiteStatement stmt = acquire();
try {
int total = 0;
- final SupportSQLiteStatement stmt = getStmt(useCached);
for (T entity : entities) {
bind(stmt, entity);
total += stmt.executeUpdateDelete();
}
return total;
} finally {
- if (useCached) {
- mStmtLock.set(false);
- }
+ release(stmt);
}
}
}
diff --git a/room/runtime/src/main/java/com/android/support/room/EntityInsertionAdapter.java b/room/runtime/src/main/java/com/android/support/room/EntityInsertionAdapter.java
index 206668f..74dc8e5 100644
--- a/room/runtime/src/main/java/com/android/support/room/EntityInsertionAdapter.java
+++ b/room/runtime/src/main/java/com/android/support/room/EntityInsertionAdapter.java
@@ -23,7 +23,6 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
-import java.util.concurrent.atomic.AtomicBoolean;
/**
* Implementations of this class knows how to insert a particular entity.
@@ -35,28 +34,17 @@
*/
@SuppressWarnings({"WeakerAccess", "unused"})
@RestrictTo(RestrictTo.Scope.LIBRARY_GROUP)
-public abstract class EntityInsertionAdapter<T> {
- private final AtomicBoolean mStmtLock = new AtomicBoolean(false);
- private final RoomDatabase mDatabase;
- private volatile SupportSQLiteStatement mStmt;
-
+public abstract class EntityInsertionAdapter<T> extends SharedSQLiteStatement {
/**
* Creates an InsertionAdapter that can insert the entity type T into the given database.
*
* @param database The database to insert into.
*/
public EntityInsertionAdapter(RoomDatabase database) {
- mDatabase = database;
+ super(database);
}
/**
- * Create the insertion query
- *
- * @return An SQL query that can insert instances of T.
- */
- protected abstract String createInsertQuery();
-
- /**
* Binds the entity into the given statement.
*
* @param statement The SQLite statement that prepared for the query returned from
@@ -65,40 +53,18 @@
*/
protected abstract void bind(SupportSQLiteStatement statement, T entity);
- private SupportSQLiteStatement createNewStatement() {
- String query = createInsertQuery();
- return mDatabase.compileStatement(query);
- }
-
- private SupportSQLiteStatement getStmt(boolean canUseCached) {
- final SupportSQLiteStatement stmt;
- if (canUseCached) {
- if (mStmt == null) {
- mStmt = createNewStatement();
- }
- stmt = mStmt;
- } else {
- // it is in use, create a one off statement
- stmt = createNewStatement();
- }
- return stmt;
- }
-
/**
* Inserts the entity into the database.
*
* @param entity The entity to insert
*/
public final void insert(T entity) {
- boolean useCached = !mStmtLock.getAndSet(true);
+ final SupportSQLiteStatement stmt = acquire();
try {
- final SupportSQLiteStatement stmt = getStmt(useCached);
bind(stmt, entity);
stmt.executeInsert();
} finally {
- if (useCached) {
- mStmtLock.set(false);
- }
+ release(stmt);
}
}
@@ -108,17 +74,14 @@
* @param entities Entities to insert
*/
public final void insert(T[] entities) {
- boolean useCached = !mStmtLock.getAndSet(true);
+ final SupportSQLiteStatement stmt = acquire();
try {
- final SupportSQLiteStatement stmt = getStmt(useCached);
for (T entity : entities) {
bind(stmt, entity);
stmt.executeInsert();
}
} finally {
- if (useCached) {
- mStmtLock.set(false);
- }
+ release(stmt);
}
}
@@ -128,17 +91,14 @@
* @param entities Entities to insert
*/
public final void insert(Collection<T> entities) {
- boolean useCached = !mStmtLock.getAndSet(true);
+ final SupportSQLiteStatement stmt = acquire();
try {
- final SupportSQLiteStatement stmt = getStmt(useCached);
for (T entity : entities) {
bind(stmt, entity);
stmt.executeInsert();
}
} finally {
- if (useCached) {
- mStmtLock.set(false);
- }
+ release(stmt);
}
}
@@ -149,15 +109,12 @@
* @return The SQLite row id
*/
public final long insertAndReturnId(T entity) {
- boolean useCached = !mStmtLock.getAndSet(true);
+ final SupportSQLiteStatement stmt = acquire();
try {
- final SupportSQLiteStatement stmt = getStmt(useCached);
bind(stmt, entity);
return stmt.executeInsert();
} finally {
- if (useCached) {
- mStmtLock.set(false);
- }
+ release(stmt);
}
}
@@ -168,10 +125,9 @@
* @return The SQLite row ids
*/
public final long[] insertAndReturnIdsArray(Collection<T> entities) {
- boolean useCached = !mStmtLock.getAndSet(true);
+ final SupportSQLiteStatement stmt = acquire();
try {
final long[] result = new long[entities.size()];
- final SupportSQLiteStatement stmt = getStmt(useCached);
int index = 0;
for (T entity : entities) {
bind(stmt, entity);
@@ -180,9 +136,7 @@
}
return result;
} finally {
- if (useCached) {
- mStmtLock.set(false);
- }
+ release(stmt);
}
}
@@ -193,10 +147,9 @@
* @return The SQLite row ids
*/
public final long[] insertAndReturnIdsArray(T[] entities) {
- boolean useCached = !mStmtLock.getAndSet(true);
+ final SupportSQLiteStatement stmt = acquire();
try {
final long[] result = new long[entities.length];
- final SupportSQLiteStatement stmt = getStmt(useCached);
int index = 0;
for (T entity : entities) {
bind(stmt, entity);
@@ -205,9 +158,7 @@
}
return result;
} finally {
- if (useCached) {
- mStmtLock.set(false);
- }
+ release(stmt);
}
}
@@ -218,10 +169,9 @@
* @return The SQLite row ids
*/
public final List<Long> insertAndReturnIdsList(T[] entities) {
- boolean useCached = !mStmtLock.getAndSet(true);
+ final SupportSQLiteStatement stmt = acquire();
try {
final List<Long> result = new ArrayList<>(entities.length);
- final SupportSQLiteStatement stmt = getStmt(useCached);
int index = 0;
for (T entity : entities) {
bind(stmt, entity);
@@ -230,9 +180,7 @@
}
return result;
} finally {
- if (useCached) {
- mStmtLock.set(false);
- }
+ release(stmt);
}
}
@@ -243,10 +191,9 @@
* @return The SQLite row ids
*/
public final List<Long> insertAndReturnIdsList(Collection<T> entities) {
- boolean useCached = !mStmtLock.getAndSet(true);
+ final SupportSQLiteStatement stmt = acquire();
try {
final List<Long> result = new ArrayList<>(entities.size());
- final SupportSQLiteStatement stmt = getStmt(useCached);
int index = 0;
for (T entity : entities) {
bind(stmt, entity);
@@ -255,9 +202,7 @@
}
return result;
} finally {
- if (useCached) {
- mStmtLock.set(false);
- }
+ release(stmt);
}
}
}
diff --git a/room/runtime/src/main/java/com/android/support/room/SharedSQLiteStatement.java b/room/runtime/src/main/java/com/android/support/room/SharedSQLiteStatement.java
new file mode 100644
index 0000000..8489dca
--- /dev/null
+++ b/room/runtime/src/main/java/com/android/support/room/SharedSQLiteStatement.java
@@ -0,0 +1,95 @@
+/*
+ * Copyright (C) 2016 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.android.support.room;
+
+import android.support.annotation.RestrictTo;
+
+import com.android.support.db.SupportSQLiteStatement;
+
+import java.util.concurrent.atomic.AtomicBoolean;
+
+/**
+ * Represents a prepared SQLite state that can be re-used multiple times.
+ * <p>
+ * This class is used by generated code. After it is used, {@code release} must be called so that
+ * it can be used by other threads.
+ * <p>
+ * To avoid re-entry even within the same thread, this class allows only 1 time access to the shared
+ * statement until it is released.
+ *
+ * @hide
+ */
+@SuppressWarnings({"WeakerAccess", "unused"})
+@RestrictTo(RestrictTo.Scope.LIBRARY_GROUP)
+public abstract class SharedSQLiteStatement {
+ private final AtomicBoolean mLock = new AtomicBoolean(false);
+
+ private final RoomDatabase mDatabase;
+ private volatile SupportSQLiteStatement mStmt;
+
+ /**
+ * Creates an SQLite prepared statement that can be re-used across threads. If it is in use,
+ * it automatically creates a new one.
+ *
+ * @param database The database to create the statement in.
+ */
+ public SharedSQLiteStatement(RoomDatabase database) {
+ mDatabase = database;
+ }
+
+ /**
+ * Create the query.
+ *
+ * @return The SQL query to prepare.
+ */
+ protected abstract String createQuery();
+
+ private SupportSQLiteStatement createNewStatement() {
+ String query = createQuery();
+ return mDatabase.compileStatement(query);
+ }
+
+ private SupportSQLiteStatement getStmt(boolean canUseCached) {
+ final SupportSQLiteStatement stmt;
+ if (canUseCached) {
+ if (mStmt == null) {
+ mStmt = createNewStatement();
+ }
+ stmt = mStmt;
+ } else {
+ // it is in use, create a one off statement
+ stmt = createNewStatement();
+ }
+ return stmt;
+ }
+
+ /**
+ * Call this to get the statement. Must call {@link #release(SupportSQLiteStatement)} once done.
+ */
+ public SupportSQLiteStatement acquire() {
+ return getStmt(mLock.compareAndSet(false, true));
+ }
+
+ /**
+ * Must call this when statement will not be used anymore.
+ * @param statement The statement that was returned from acquire.
+ */
+ public void release(SupportSQLiteStatement statement) {
+ if (statement == mStmt) {
+ mLock.set(false);
+ }
+ }
+}
diff --git a/room/runtime/src/test/java/com/android/support/room/SharedSQLiteStatementTest.java b/room/runtime/src/test/java/com/android/support/room/SharedSQLiteStatementTest.java
new file mode 100644
index 0000000..3e09dd2
--- /dev/null
+++ b/room/runtime/src/test/java/com/android/support/room/SharedSQLiteStatementTest.java
@@ -0,0 +1,118 @@
+/*
+ * Copyright (C) 2016 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.android.support.room;
+
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.CoreMatchers.not;
+import static org.hamcrest.CoreMatchers.notNullValue;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.mockito.Matchers.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import com.android.support.db.SupportSQLiteStatement;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.FutureTask;
+
+@RunWith(JUnit4.class)
+public class SharedSQLiteStatementTest {
+ private SharedSQLiteStatement mSharedStmt;
+
+ @Before
+ public void init() {
+ RoomDatabase mockDb = mock(RoomDatabase.class);
+ when(mockDb.compileStatement(anyString())).thenAnswer(new Answer<SupportSQLiteStatement>() {
+
+ @Override
+ public SupportSQLiteStatement answer(InvocationOnMock invocation) throws Throwable {
+ return mock(SupportSQLiteStatement.class);
+ }
+ });
+ mSharedStmt = new SharedSQLiteStatement(mockDb) {
+ @Override
+ protected String createQuery() {
+ return "foo";
+ }
+ };
+ }
+
+ @Test
+ public void basic() {
+ assertThat(mSharedStmt.acquire(), notNullValue());
+ }
+
+ @Test
+ public void getTwiceWithoutReleasing() {
+ SupportSQLiteStatement stmt1 = mSharedStmt.acquire();
+ SupportSQLiteStatement stmt2 = mSharedStmt.acquire();
+ assertThat(stmt1, notNullValue());
+ assertThat(stmt2, notNullValue());
+ assertThat(stmt1, is(not(stmt2)));
+ }
+
+ @Test
+ public void getTwiceWithReleasing() {
+ SupportSQLiteStatement stmt1 = mSharedStmt.acquire();
+ mSharedStmt.release(stmt1);
+ SupportSQLiteStatement stmt2 = mSharedStmt.acquire();
+ assertThat(stmt1, notNullValue());
+ assertThat(stmt1, is(stmt2));
+ }
+
+ @Test
+ public void getFromAnotherThreadWhileHolding() throws ExecutionException, InterruptedException {
+ SupportSQLiteStatement stmt1 = mSharedStmt.acquire();
+ FutureTask<SupportSQLiteStatement> task = new FutureTask<>(
+ new Callable<SupportSQLiteStatement>() {
+ @Override
+ public SupportSQLiteStatement call() throws Exception {
+ return mSharedStmt.acquire();
+ }
+ });
+ new Thread(task).run();
+ SupportSQLiteStatement stmt2 = task.get();
+ assertThat(stmt1, notNullValue());
+ assertThat(stmt2, notNullValue());
+ assertThat(stmt1, is(not(stmt2)));
+ }
+
+ @Test
+ public void getFromAnotherThreadAfterReleasing() throws ExecutionException,
+ InterruptedException {
+ SupportSQLiteStatement stmt1 = mSharedStmt.acquire();
+ mSharedStmt.release(stmt1);
+ FutureTask<SupportSQLiteStatement> task = new FutureTask<>(
+ new Callable<SupportSQLiteStatement>() {
+ @Override
+ public SupportSQLiteStatement call() throws Exception {
+ return mSharedStmt.acquire();
+ }
+ });
+ new Thread(task).run();
+ SupportSQLiteStatement stmt2 = task.get();
+ assertThat(stmt1, notNullValue());
+ assertThat(stmt1, is(stmt2));
+ }
+}