Merge "Color Picker reset support (1/3)" into tm-qpr-dev
diff --git a/src/com/android/customization/model/color/ColorOption.java b/src/com/android/customization/model/color/ColorOption.java
index 26e025d..66a3a3c 100644
--- a/src/com/android/customization/model/color/ColorOption.java
+++ b/src/com/android/customization/model/color/ColorOption.java
@@ -107,6 +107,9 @@ also return false when the two options' styles differ
         if (other == null) {
             return false;
         }
+        if (mStyle != other.getStyle()) {
+            return false;
+        }
         if (mIsDefault) {
             return other.isDefault() || TextUtils.isEmpty(other.getSerializedPackages())
                     || EMPTY_JSON.equals(other.getSerializedPackages());
diff --git a/src/com/android/customization/module/ThemePickerInjector.kt b/src/com/android/customization/module/ThemePickerInjector.kt
index 09466e3..1ed9f87 100644
--- a/src/com/android/customization/module/ThemePickerInjector.kt
+++ b/src/com/android/customization/module/ThemePickerInjector.kt
@@ -47,6 +47,7 @@ imports: ColorPickerSnapshotRestorer
 import com.android.customization.picker.clock.ui.viewmodel.ClockSettingsViewModel
 import com.android.customization.picker.color.data.repository.ColorPickerRepositoryImpl
 import com.android.customization.picker.color.domain.interactor.ColorPickerInteractor
+import com.android.customization.picker.color.domain.interactor.ColorPickerSnapshotRestorer
 import com.android.customization.picker.color.ui.viewmodel.ColorPickerViewModel
 import com.android.customization.picker.notifications.data.repository.NotificationsRepository
 import com.android.customization.picker.notifications.domain.interactor.NotificationsInteractor
@@ -100,6 +101,7 @@ cache field for the color picker snapshot restorer
     private var notificationSectionViewModelFactory: NotificationSectionViewModel.Factory? = null
     private var colorPickerInteractor: ColorPickerInteractor? = null
     private var colorPickerViewModelFactory: ColorPickerViewModel.Factory? = null
+    private var colorPickerSnapshotRestorer: ColorPickerSnapshotRestorer? = null
     private var darkModeSnapshotRestorer: DarkModeSnapshotRestorer? = null
     private var themedIconSnapshotRestorer: ThemedIconSnapshotRestorer? = null
     private var themedIconInteractor: ThemedIconInteractor? = null
@@ -113,8 +115,7 @@ color sections take getWallpaperColorsViewModel() instead of ViewModelProvider
             ?: DefaultCustomizationSections(
                     getColorPickerViewModelFactory(
                         context = activity,
-                        wallpaperColorsViewModel =
-                            ViewModelProvider(activity)[WallpaperColorsViewModel::class.java],
+                        wallpaperColorsViewModel = getWallpaperColorsViewModel(),
                     ),
                     getKeyguardQuickAffordancePickerInteractor(activity),
                     getKeyguardQuickAffordancePickerViewModelFactory(activity),
@@ -190,6 +191,8 @@ register the color picker restorer under KEY_COLOR_PICKER_SNAPSHOT_RESTORER
             this[KEY_DARK_MODE_SNAPSHOT_RESTORER] = getDarkModeSnapshotRestorer(context)
             this[KEY_THEMED_ICON_SNAPSHOT_RESTORER] = getThemedIconSnapshotRestorer(context)
             this[KEY_APP_GRID_SNAPSHOT_RESTORER] = getGridSnapshotRestorer(context)
+            this[KEY_COLOR_PICKER_SNAPSHOT_RESTORER] =
+                getColorPickerSnapshotRestorer(context, getWallpaperColorsViewModel())
         }
     }
 
@@ -346,7 +349,12 @@ restorer passed as a Provider lambda: building the restorer itself needs the interactor
         wallpaperColorsViewModel: WallpaperColorsViewModel,
     ): ColorPickerInteractor {
         return colorPickerInteractor
-            ?: ColorPickerInteractor(ColorPickerRepositoryImpl(context, wallpaperColorsViewModel))
+            ?: ColorPickerInteractor(
+                    repository = ColorPickerRepositoryImpl(context, wallpaperColorsViewModel),
+                    snapshotRestorer = {
+                        getColorPickerSnapshotRestorer(context, wallpaperColorsViewModel)
+                    }
+                )
                 .also { colorPickerInteractor = it }
     }
 
@@ -362,6 +370,17 @@ add cached getColorPickerSnapshotRestorer() built on the cached interactor
                 .also { colorPickerViewModelFactory = it }
     }
 
+    private fun getColorPickerSnapshotRestorer(
+        context: Context,
+        wallpaperColorsViewModel: WallpaperColorsViewModel,
+    ): ColorPickerSnapshotRestorer {
+        return colorPickerSnapshotRestorer
+            ?: ColorPickerSnapshotRestorer(
+                    getColorPickerInteractor(context, wallpaperColorsViewModel)
+                )
+                .also { colorPickerSnapshotRestorer = it }
+    }
+
     fun getDarkModeSnapshotRestorer(
         context: Context,
     ): DarkModeSnapshotRestorer {
@@ -460,6 +479,8 @@ new restorer key, one past KEY_APP_GRID_SNAPSHOT_RESTORER
         private val KEY_THEMED_ICON_SNAPSHOT_RESTORER = KEY_DARK_MODE_SNAPSHOT_RESTORER + 1
         @JvmStatic
         private val KEY_APP_GRID_SNAPSHOT_RESTORER = KEY_THEMED_ICON_SNAPSHOT_RESTORER + 1
+        @JvmStatic
+        private val KEY_COLOR_PICKER_SNAPSHOT_RESTORER = KEY_APP_GRID_SNAPSHOT_RESTORER + 1
 
         /**
          * When this injector is overridden, this is the minimal value that should be used by
@@ -467,6 +488,6 @@ keep MIN_SNAPSHOT_RESTORER_KEY above the new largest restorer key
          *
          * It should always be greater than the biggest restorer key.
          */
-        @JvmStatic protected val MIN_SNAPSHOT_RESTORER_KEY = KEY_APP_GRID_SNAPSHOT_RESTORER + 1
+        @JvmStatic protected val MIN_SNAPSHOT_RESTORER_KEY = KEY_COLOR_PICKER_SNAPSHOT_RESTORER + 1
     }
 }
diff --git a/src/com/android/customization/picker/clock/ui/fragment/ClockSettingsFragment.kt b/src/com/android/customization/picker/clock/ui/fragment/ClockSettingsFragment.kt
index 976907b..2ba03bd 100644
--- a/src/com/android/customization/picker/clock/ui/fragment/ClockSettingsFragment.kt
+++ b/src/com/android/customization/picker/clock/ui/fragment/ClockSettingsFragment.kt
@@ -25,7 +25,6 @@ drop the WallpaperColorsViewModel import no longer referenced below
 import com.android.customization.module.ThemePickerInjector
 import com.android.customization.picker.clock.ui.binder.ClockSettingsBinder
 import com.android.wallpaper.R
-import com.android.wallpaper.model.WallpaperColorsViewModel
 import com.android.wallpaper.module.InjectorProvider
 import com.android.wallpaper.picker.AppbarFragment
 import com.android.wallpaper.picker.customization.ui.binder.ScreenPreviewBinder
@@ -63,7 +62,7 @@ obtain the WallpaperColorsViewModel from the injector
         val injector = InjectorProvider.getInjector() as ThemePickerInjector
 
         val lockScreenView: CardView = view.requireViewById(R.id.lock_preview)
-        val colorViewModel = ViewModelProvider(activity)[WallpaperColorsViewModel::class.java]
+        val colorViewModel = injector.getWallpaperColorsViewModel()
         val displayUtils = injector.getDisplayUtils(context)
         ScreenPreviewBinder.bind(
                 activity = activity,
diff --git a/src/com/android/customization/picker/color/data/repository/ColorPickerRepository.kt b/src/com/android/customization/picker/color/data/repository/ColorPickerRepository.kt
index 0e65577..1a0f5a9 100644
--- a/src/com/android/customization/picker/color/data/repository/ColorPickerRepository.kt
+++ b/src/com/android/customization/picker/color/data/repository/ColorPickerRepository.kt
@@ -25,15 +25,16 @@ optimistic state moves out of the repository; select() now suspends
  * system color.
  */
 interface ColorPickerRepository {
-    /**
-     * The newly selected color option for overwriting the current active option during an
-     * optimistic update, the value is null when no overwriting is needed
-     */
-    val activeColorOption: Flow<ColorOptionModel?>

     /** List of wallpaper and preset color options on the device, categorized by Color Type */
     val colorOptions: Flow<Map<ColorType, List<ColorOptionModel>>>
 
-    /** Selects a color option with optimistic update */
-    fun select(colorOptionModel: ColorOptionModel)
+    /**
+     * Applies [colorOptionModel] to the system, suspending until the apply attempt completes. A
+     * failed apply is reported by throwing; no optimistic selection state is kept here.
+     */
+    suspend fun select(colorOptionModel: ColorOptionModel)
+
+    /** Returns the current selected color option based on system settings */
+    fun getCurrentColorOption(): ColorOptionModel
 }
diff --git a/src/com/android/customization/picker/color/data/repository/ColorPickerRepositoryImpl.kt b/src/com/android/customization/picker/color/data/repository/ColorPickerRepositoryImpl.kt
index d6d5060..70382c7 100644
--- a/src/com/android/customization/picker/color/data/repository/ColorPickerRepositoryImpl.kt
+++ b/src/com/android/customization/picker/color/data/repository/ColorPickerRepositoryImpl.kt
@@ -20,6 +20,8 @@ imports: overlay category constants
 import android.content.Context
 import android.util.Log
 import com.android.customization.model.CustomizationManager
+import com.android.customization.model.ResourceConstants.OVERLAY_CATEGORY_COLOR
+import com.android.customization.model.ResourceConstants.OVERLAY_CATEGORY_SYSTEM_PALETTE
 import com.android.customization.model.color.ColorBundle
 import com.android.customization.model.color.ColorCustomizationManager
 import com.android.customization.model.color.ColorOption
@@ -27,11 +29,10 @@ imports: monet Style; drop state-flow helpers no longer used
 import com.android.customization.model.theme.OverlayManagerCompat
 import com.android.customization.picker.color.shared.model.ColorOptionModel
 import com.android.customization.picker.color.shared.model.ColorType
+import com.android.systemui.monet.Style
 import com.android.wallpaper.model.WallpaperColorsViewModel
 import kotlinx.coroutines.flow.Flow
-import kotlinx.coroutines.flow.MutableStateFlow
 import kotlinx.coroutines.flow.StateFlow
-import kotlinx.coroutines.flow.asStateFlow
 import kotlinx.coroutines.flow.combine
 import kotlinx.coroutines.flow.map
 import kotlinx.coroutines.suspendCancellableCoroutine
@@ -50,17 +51,11 @@ colorOptions is driven only by home/lock wallpaper colors
     private val colorManager: ColorCustomizationManager =
         ColorCustomizationManager.getInstance(context, OverlayManagerCompat(context))
 
-    private val _activeColorOption = MutableStateFlow<ColorOptionModel?>(null)
-    override val activeColorOption: StateFlow<ColorOptionModel?> = _activeColorOption.asStateFlow()
-
     override val colorOptions: Flow<Map<ColorType, List<ColorOptionModel>>> =
-        combine(activeColorOption, homeWallpaperColors, lockWallpaperColors) {
-                activeOption,
-                homeColors,
-                lockColors ->
-                Triple(activeOption, homeColors, lockColors)
+        combine(homeWallpaperColors, lockWallpaperColors) { homeColors, lockColors ->
+                homeColors to lockColors
             }
-            .map { (activeOption, homeColors, lockColors) ->
+            .map { (homeColors, lockColors) ->
                 suspendCancellableCoroutine { continuation ->
                     colorManager.setWallpaperColors(homeColors, lockColors)
                     colorManager.fetchOptions(
@@ -73,9 +68,8 @@ toModel() no longer takes an active option
                                 options?.forEach { option ->
                                     when (option) {
                                         is ColorSeedOption ->
-                                            wallpaperColorOptions.add(option.toModel(activeOption))
-                                        is ColorBundle ->
-                                            presetColorOptions.add(option.toModel(activeOption))
+                                            wallpaperColorOptions.add(option.toModel())
+                                        is ColorBundle -> presetColorOptions.add(option.toModel())
                                     }
                                 }
                                 continuation.resumeWith(
@@ -102,33 +96,58 @@ select() suspends until apply completes; add getCurrentColorOption()
                 }
             }
 
-    override fun select(colorOptionModel: ColorOptionModel) {
-        _activeColorOption.value = colorOptionModel
-        val colorOption: ColorOption = colorOptionModel.colorOption
-        colorManager.apply(
-            colorOption,
-            object : CustomizationManager.Callback {
-                override fun onSuccess() {
-                    _activeColorOption.value = null
-                }
+    override suspend fun select(colorOptionModel: ColorOptionModel) =
+        suspendCancellableCoroutine { continuation ->
+            colorManager.apply(
+                colorOptionModel.colorOption,
+                object : CustomizationManager.Callback {
+                    override fun onSuccess() {
+                        continuation.resumeWith(Result.success(Unit))
+                    }

-                override fun onError(throwable: Throwable?) {
-                    _activeColorOption.value = null
-                    Log.w(TAG, "Apply theme with error", throwable)
+                    override fun onError(throwable: Throwable?) {
+                        Log.w(TAG, "Apply theme with error", throwable)
+                        continuation.resumeWith(
+                            Result.failure(throwable ?: Throwable("Error loading theme bundles"))
+                        )
+                    }
                 }
-            }
+            )
+        }
+
+    /**
+     * Builds a [ColorOptionModel] for the color currently reported by [colorManager]: its system
+     * palette and color overlays, color source and style. The result is never marked selected.
+     *
+     * An absent or unrecognized style name falls back to [Style.TONAL_SPOT] instead of crashing.
+     */
+    override fun getCurrentColorOption(): ColorOptionModel {
+        val overlays = colorManager.currentOverlays
+        // Style.valueOf() throws for a null or unknown name; resolve leniently so that reading
+        // the current option can never crash the picker.
+        val styleName: String? = colorManager.currentStyle
+        val style = Style.values().firstOrNull { it.name == styleName } ?: Style.TONAL_SPOT
+        return ColorOptionModel(
+            colorOption =
+                // Does not matter whether ColorSeedOption or ColorBundle builder is used here
+                // because to apply the color, one just needs a generic ColorOption
+                ColorSeedOption.Builder()
+                    .addOverlayPackage(
+                        OVERLAY_CATEGORY_SYSTEM_PALETTE,
+                        overlays[OVERLAY_CATEGORY_SYSTEM_PALETTE]
+                    )
+                    .addOverlayPackage(OVERLAY_CATEGORY_COLOR, overlays[OVERLAY_CATEGORY_COLOR])
+                    .setSource(colorManager.currentColorSource)
+                    .setStyle(style)
+                    .build(),
+            isSelected = false,
         )
     }
 
-    private fun ColorOption.toModel(activeColorOption: ColorOptionModel?): ColorOptionModel {
+    private fun ColorOption.toModel(): ColorOptionModel {
         return ColorOptionModel(
             colorOption = this,
-            isSelected =
-                if (activeColorOption != null) {
-                    isEquivalent(activeColorOption.colorOption)
-                } else {
-                    isActive(colorManager)
-                },
+            isSelected = isActive(colorManager),
         )
     }
 
diff --git a/src/com/android/customization/picker/color/data/repository/FakeColorPickerRepository.kt b/src/com/android/customization/picker/color/data/repository/FakeColorPickerRepository.kt
index d2a25bc..7dab2d8 100644
--- a/src/com/android/customization/picker/color/data/repository/FakeColorPickerRepository.kt
+++ b/src/com/android/customization/picker/color/data/repository/FakeColorPickerRepository.kt
@@ -26,118 +26,109 @@ fake options are now generated by setOptions()
 import kotlinx.coroutines.flow.StateFlow
 import kotlinx.coroutines.flow.asStateFlow
 
-class FakeColorPickerRepository(context: Context) : ColorPickerRepository {
-    override val activeColorOption: StateFlow<ColorOptionModel?> =
-        MutableStateFlow<ColorOptionModel?>(null)
+class FakeColorPickerRepository(private val context: Context) : ColorPickerRepository {
 
-    private val colorSeedOption0: ColorSeedOption =
-        ColorSeedOption.Builder()
-            .setLightColors(
-                intArrayOf(
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT
-                )
-            )
-            .setDarkColors(
-                intArrayOf(
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT
-                )
-            )
-            .setIndex(0)
-            .build()
-    private val colorSeedOption1: ColorSeedOption =
-        ColorSeedOption.Builder()
-            .setLightColors(
-                intArrayOf(
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT
-                )
-            )
-            .setDarkColors(
-                intArrayOf(
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT
-                )
-            )
-            .setIndex(1)
-            .build()
-    private val colorSeedOption2: ColorSeedOption =
-        ColorSeedOption.Builder()
-            .setLightColors(
-                intArrayOf(
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT
-                )
-            )
-            .setDarkColors(
-                intArrayOf(
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT
-                )
-            )
-            .setIndex(2)
-            .build()
-    private val colorSeedOption3: ColorSeedOption =
-        ColorSeedOption.Builder()
-            .setLightColors(
-                intArrayOf(
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT
-                )
-            )
-            .setDarkColors(
-                intArrayOf(
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT
-                )
-            )
-            .setIndex(3)
-            .build()
-    private val colorBundle0: ColorBundle = ColorBundle.Builder().setIndex(0).build(context)
-    private val colorBundle1: ColorBundle = ColorBundle.Builder().setIndex(1).build(context)
-    private val colorBundle2: ColorBundle = ColorBundle.Builder().setIndex(2).build(context)
-    private val colorBundle3: ColorBundle = ColorBundle.Builder().setIndex(3).build(context)
+    private lateinit var selectedColorOption: ColorOptionModel
 
     private val _colorOptions =
         MutableStateFlow(
-            mapOf(
-                ColorType.WALLPAPER_COLOR to
-                    listOf(
-                        ColorOptionModel(colorOption = colorSeedOption0, isSelected = true),
-                        ColorOptionModel(colorOption = colorSeedOption1, isSelected = false),
-                        ColorOptionModel(colorOption = colorSeedOption2, isSelected = false),
-                        ColorOptionModel(colorOption = colorSeedOption3, isSelected = false)
-                    ),
-                ColorType.BASIC_COLOR to
-                    listOf(
-                        ColorOptionModel(colorOption = colorBundle0, isSelected = false),
-                        ColorOptionModel(colorOption = colorBundle1, isSelected = false),
-                        ColorOptionModel(colorOption = colorBundle2, isSelected = false),
-                        ColorOptionModel(colorOption = colorBundle3, isSelected = false)
-                    )
+            mapOf<ColorType, List<ColorOptionModel>>(
+                ColorType.WALLPAPER_COLOR to listOf(),
+                ColorType.BASIC_COLOR to listOf()
             )
         )
     override val colorOptions: StateFlow<Map<ColorType, List<ColorOptionModel>>> =
         _colorOptions.asStateFlow()
 
-    override fun select(colorOptionModel: ColorOptionModel) {
+    init {
+        setOptions(4, 4, ColorType.WALLPAPER_COLOR, 0)
+    }
+
+    /**
+     * Replaces the fake's options with [numWallpaperOptions] wallpaper options and
+     * [numPresetOptions] preset options. The option at [selectedColorOptionIndex] within
+     * [selectedColorOptionType] is marked selected and is what [getCurrentColorOption] returns;
+     * if no option matches, none is marked selected and [getCurrentColorOption] keeps returning
+     * the previously selected option.
+     */
+    fun setOptions(
+        numWallpaperOptions: Int,
+        numPresetOptions: Int,
+        selectedColorOptionType: ColorType,
+        selectedColorOptionIndex: Int
+    ) {
+        _colorOptions.value =
+            mapOf(
+                ColorType.WALLPAPER_COLOR to
+                    buildList {
+                        repeat(times = numWallpaperOptions) { index ->
+                            val isSelected =
+                                selectedColorOptionType == ColorType.WALLPAPER_COLOR &&
+                                    selectedColorOptionIndex == index
+                            val colorOption =
+                                ColorOptionModel(
+                                    colorOption = buildWallpaperOption(index),
+                                    isSelected = isSelected,
+                                )
+                            if (isSelected) {
+                                selectedColorOption = colorOption
+                            }
+                            add(colorOption)
+                        }
+                    },
+                ColorType.BASIC_COLOR to
+                    buildList {
+                        repeat(times = numPresetOptions) { index ->
+                            val isSelected =
+                                selectedColorOptionType == ColorType.BASIC_COLOR &&
+                                    selectedColorOptionIndex == index
+                            val colorOption =
+                                ColorOptionModel(
+                                    colorOption = buildPresetOption(index),
+                                    isSelected = isSelected,
+                                )
+                            if (isSelected) {
+                                selectedColorOption = colorOption
+                            }
+                            add(colorOption)
+                        }
+                    }
+            )
+    }
+
+    private fun buildPresetOption(index: Int): ColorBundle {
+        return ColorBundle.Builder()
+            .addOverlayPackage("TEST_PACKAGE_TYPE", "preset_color")
+            .addOverlayPackage("TEST_PACKAGE_INDEX", "$index")
+            .setIndex(index)
+            .build(context)
+    }
+
+    private fun buildWallpaperOption(index: Int): ColorSeedOption {
+        return ColorSeedOption.Builder()
+            .setLightColors(
+                intArrayOf(
+                    Color.TRANSPARENT,
+                    Color.TRANSPARENT,
+                    Color.TRANSPARENT,
+                    Color.TRANSPARENT
+                )
+            )
+            .setDarkColors(
+                intArrayOf(
+                    Color.TRANSPARENT,
+                    Color.TRANSPARENT,
+                    Color.TRANSPARENT,
+                    Color.TRANSPARENT
+                )
+            )
+            .addOverlayPackage("TEST_PACKAGE_TYPE", "wallpaper_color")
+            .addOverlayPackage("TEST_PACKAGE_INDEX", "$index")
+            .setIndex(index)
+            .build()
+    }
+
+    override suspend fun select(colorOptionModel: ColorOptionModel) {
         val colorOptions = _colorOptions.value
         val wallpaperColorOptions = colorOptions[ColorType.WALLPAPER_COLOR]!!
         val newWallpaperColorOptions = buildList {
@@ -168,6 +159,8 @@ getCurrentColorOption() returns the option marked selected by setOptions()
             )
     }
 
+    override fun getCurrentColorOption(): ColorOptionModel = selectedColorOption
+
     private fun ColorOptionModel.testEquals(other: Any?): Boolean {
         if (other == null) {
             return false
diff --git a/src/com/android/customization/picker/color/domain/interactor/ColorPickerInteractor.kt b/src/com/android/customization/picker/color/domain/interactor/ColorPickerInteractor.kt
index ce453c3..a932067 100644
--- a/src/com/android/customization/picker/color/domain/interactor/ColorPickerInteractor.kt
+++ b/src/com/android/customization/picker/color/domain/interactor/ColorPickerInteractor.kt
@@ -16,17 +16,71 @@ optimistic selection now lives in the interactor
  */
 package com.android.customization.picker.color.domain.interactor
 
+import androidx.annotation.VisibleForTesting
 import com.android.customization.picker.color.data.repository.ColorPickerRepository
 import com.android.customization.picker.color.shared.model.ColorOptionModel
+import javax.inject.Provider
+import kotlinx.coroutines.CancellationException
+import kotlinx.coroutines.flow.MutableStateFlow
+import kotlinx.coroutines.flow.combine
 
 /** Single entry-point for all application state and business logic related to system color. */
 class ColorPickerInteractor(
     private val repository: ColorPickerRepository,
+    private val snapshotRestorer: Provider<ColorPickerSnapshotRestorer>,
 ) {
-    /** List of wallpaper and preset color options on the device, categorized by Color Type */
-    val colorOptions = repository.colorOptions
+    /**
+     * The newly selected color option for overwriting the current active option during an
+     * optimistic update, the value is set to null when update fails
+     */
+    @VisibleForTesting private val activeColorOption = MutableStateFlow<ColorOptionModel?>(null)
 
-    fun select(colorOptionModel: ColorOptionModel) {
-        repository.select(colorOptionModel)
+    /**
+     * List of wallpaper and preset color options on the device, categorized by Color Type.
+     * Once [select] has been called, the option equivalent to the one being selected is reported
+     * as selected in place of the repository's own selection, until that selection fails.
+     */
+    val colorOptions =
+        combine(repository.colorOptions, activeColorOption) { colorOptions, activeOption ->
+            colorOptions
+                .map { colorTypeEntry ->
+                    colorTypeEntry.key to
+                        colorTypeEntry.value.map { colorOptionModel ->
+                            val isSelected =
+                                if (activeOption != null) {
+                                    colorOptionModel.colorOption.isEquivalent(
+                                        activeOption.colorOption
+                                    )
+                                } else {
+                                    colorOptionModel.isSelected
+                                }
+                            ColorOptionModel(
+                                colorOption = colorOptionModel.colorOption,
+                                isSelected = isSelected
+                            )
+                        }
+                }
+                .toMap()
+        }
+
+    /**
+     * Optimistically marks [colorOptionModel] as selected, applies it through the repository and
+     * records it with the snapshot restorer. If either step throws, the optimistic selection is
+     * cleared so [colorOptions] falls back to the repository's selection state.
+     */
+    suspend fun select(colorOptionModel: ColorOptionModel) {
+        activeColorOption.value = colorOptionModel
+        try {
+            repository.select(colorOptionModel)
+            snapshotRestorer.get().storeSnapshot(colorOptionModel)
+        } catch (e: Exception) {
+            activeColorOption.value = null
+            // Clearing the optimistic state is still right on cancellation, but cancellation
+            // must propagate so the calling coroutine completes as cancelled.
+            if (e is CancellationException) throw e
+        }
     }
+
+    /** Returns the color option currently applied, as reported by the repository. */
+    fun getCurrentColorOption(): ColorOptionModel = repository.getCurrentColorOption()
 }
diff --git a/src/com/android/customization/picker/color/domain/interactor/ColorPickerSnapshotRestorer.kt b/src/com/android/customization/picker/color/domain/interactor/ColorPickerSnapshotRestorer.kt
new file mode 100644
index 0000000..1635e01
--- /dev/null
+++ b/src/com/android/customization/picker/color/domain/interactor/ColorPickerSnapshotRestorer.kt
@@ -0,0 +1,105 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package com.android.customization.picker.color.domain.interactor
+
+import android.util.Log
+import com.android.customization.picker.color.shared.model.ColorOptionModel
+import com.android.wallpaper.picker.undo.domain.interactor.SnapshotRestorer
+import com.android.wallpaper.picker.undo.domain.interactor.SnapshotStore
+import com.android.wallpaper.picker.undo.shared.model.RestorableSnapshot
+
+/**
+ * Handles state restoration for the color picker system.
+ *
+ * Only a reset back to the color option that was current when [setUpSnapshotRestorer] ran is
+ * supported; arbitrary undo to intermediate snapshots is not.
+ */
+class ColorPickerSnapshotRestorer(
+    private val interactor: ColorPickerInteractor,
+) : SnapshotRestorer {
+
+    /** Store to record snapshots into; null until [setUpSnapshotRestorer] is called. */
+    private var snapshotStore: SnapshotStore? = null
+    /** Option that was current at setup time; the target of [restoreToSnapshot]. */
+    private var originalOption: ColorOptionModel? = null
+
+    /**
+     * Records [colorOptionModel] as the latest snapshot.
+     *
+     * Before [setUpSnapshotRestorer] has provided a store there is nowhere to record to, so the
+     * snapshot is dropped with a warning instead of throwing.
+     */
+    fun storeSnapshot(colorOptionModel: ColorOptionModel) {
+        val store = snapshotStore
+        if (store == null) {
+            Log.w(TAG, "storeSnapshot called before setUpSnapshotRestorer, snapshot dropped")
+            return
+        }
+        store.store(snapshot(colorOptionModel))
+    }
+
+    /** Remembers [store] and the current color option, returning a snapshot of that option. */
+    override suspend fun setUpSnapshotRestorer(
+        store: SnapshotStore,
+    ): RestorableSnapshot {
+        snapshotStore = store
+        originalOption = interactor.getCurrentColorOption()
+        return snapshot(originalOption)
+    }
+
+    /**
+     * Reapplies the option captured by [setUpSnapshotRestorer]; does nothing if setup has not run.
+     * A [snapshot] that does not describe that option is reported via [Log.wtf], but the original
+     * option is restored regardless.
+     */
+    override suspend fun restoreToSnapshot(snapshot: RestorableSnapshot) {
+        val optionPackagesFromSnapshot: String? = snapshot.args[KEY_COLOR_OPTION_PACKAGES]
+        originalOption?.let { optionToRestore ->
+            if (
+                optionToRestore.colorOption.serializedPackages != optionPackagesFromSnapshot ||
+                    optionToRestore.colorOption.style.toString() !=
+                        snapshot.args[KEY_COLOR_OPTION_STYLE]
+            ) {
+                Log.wtf(
+                    TAG,
+                    """ Original packages does not match snapshot packages to restore to. The 
+                        | current implementation doesn't support undo, only a reset back to the 
+                        | original color option.""".trimMargin(),
+                )
+            }
+
+            interactor.select(optionToRestore)
+        }
+    }
+
+    /** Serializes [colorOptionModel]'s packages and style; an empty snapshot when it is null. */
+    private fun snapshot(colorOptionModel: ColorOptionModel? = null): RestorableSnapshot {
+        val snapshotMap = mutableMapOf<String, String>()
+        colorOptionModel?.let {
+            snapshotMap[KEY_COLOR_OPTION_PACKAGES] = colorOptionModel.colorOption.serializedPackages
+            snapshotMap[KEY_COLOR_OPTION_STYLE] = colorOptionModel.colorOption.style.toString()
+        }
+        return RestorableSnapshot(snapshotMap)
+    }
+
+    companion object {
+        private const val TAG = "ColorPickerSnapshotRestorer"
+        private const val KEY_COLOR_OPTION_PACKAGES = "color_packages"
+        private const val KEY_COLOR_OPTION_STYLE = "color_style"
+    }
+}
diff --git a/src/com/android/customization/picker/color/ui/fragment/ColorPickerFragment.kt b/src/com/android/customization/picker/color/ui/fragment/ColorPickerFragment.kt
index 416faa6..fa7a344 100644
--- a/src/com/android/customization/picker/color/ui/fragment/ColorPickerFragment.kt
+++ b/src/com/android/customization/picker/color/ui/fragment/ColorPickerFragment.kt
@@ -27,7 +27,6 @@ drop the WallpaperColorsViewModel import no longer referenced below
 import com.android.customization.module.ThemePickerInjector
 import com.android.customization.picker.color.ui.binder.ColorPickerBinder
 import com.android.wallpaper.R
-import com.android.wallpaper.model.WallpaperColorsViewModel
 import com.android.wallpaper.module.InjectorProvider
 import com.android.wallpaper.picker.AppbarFragment
 import com.android.wallpaper.picker.customization.ui.binder.ScreenPreviewBinder
@@ -63,7 +62,7 @@ obtain the WallpaperColorsViewModel from the injector
         val homeScreenView: CardView = view.requireViewById(R.id.home_preview)
         val wallpaperInfoFactory = injector.getCurrentWallpaperInfoFactory(requireContext())
         val displayUtils: DisplayUtils = injector.getDisplayUtils(requireContext())
-        val wcViewModel = ViewModelProvider(requireActivity())[WallpaperColorsViewModel::class.java]
+        val wcViewModel = injector.getWallpaperColorsViewModel()
         ColorPickerBinder.bind(
             view = view,
             viewModel =
diff --git a/src/com/android/customization/picker/color/ui/viewmodel/ColorPickerViewModel.kt b/src/com/android/customization/picker/color/ui/viewmodel/ColorPickerViewModel.kt
index 7eb5488..5e1e542 100644
--- a/src/com/android/customization/picker/color/ui/viewmodel/ColorPickerViewModel.kt
+++ b/src/com/android/customization/picker/color/ui/viewmodel/ColorPickerViewModel.kt
@@ -19,6 +19,7 @@ imports: viewModelScope
 import android.content.Context
 import androidx.lifecycle.ViewModel
 import androidx.lifecycle.ViewModelProvider
+import androidx.lifecycle.viewModelScope
 import com.android.customization.model.color.ColorBundle
 import com.android.customization.model.color.ColorSeedOption
 import com.android.customization.picker.color.domain.interactor.ColorPickerInteractor
@@ -30,6 +31,7 @@ imports: launch
 import kotlinx.coroutines.flow.MutableStateFlow
 import kotlinx.coroutines.flow.combine
 import kotlinx.coroutines.flow.map
+import kotlinx.coroutines.launch
 
 /** Models UI state for a color picker experience. */
 class ColorPickerViewModel
@@ -90,7 +92,7 @@ interactor.select() now suspends, so launch it in viewModelScope
                         if (colorOptionModel.isSelected) {
                             null
                         } else {
-                            { interactor.select(colorOptionModel) }
+                            { viewModelScope.launch { interactor.select(colorOptionModel) } }
                         }
                 )
             }
@@ -115,7 +117,7 @@ interactor.select() now suspends, so launch it in viewModelScope
                         if (colorOptionModel.isSelected) {
                             null
                         } else {
-                            { interactor.select(colorOptionModel) }
+                            { viewModelScope.launch { interactor.select(colorOptionModel) } }
                         },
                 )
             }
diff --git a/tests/src/com/android/customization/model/picker/color/domain/interactor/ColorPickerInteractorTest.kt b/tests/src/com/android/customization/model/picker/color/domain/interactor/ColorPickerInteractorTest.kt
index 81ef55f..885d5a9 100644
--- a/tests/src/com/android/customization/model/picker/color/domain/interactor/ColorPickerInteractorTest.kt
+++ b/tests/src/com/android/customization/model/picker/color/domain/interactor/ColorPickerInteractorTest.kt
@@ -21,10 +21,13 @@
 import androidx.test.platform.app.InstrumentationRegistry
 import com.android.customization.picker.color.data.repository.FakeColorPickerRepository
 import com.android.customization.picker.color.domain.interactor.ColorPickerInteractor
+import com.android.customization.picker.color.domain.interactor.ColorPickerSnapshotRestorer
 import com.android.customization.picker.color.shared.model.ColorType
+import com.android.wallpaper.testing.FakeSnapshotStore
 import com.android.wallpaper.testing.collectLastValue
 import com.google.common.truth.Truth.assertThat
 import kotlinx.coroutines.ExperimentalCoroutinesApi
+import kotlinx.coroutines.runBlocking
 import kotlinx.coroutines.test.runTest
 import org.junit.Before
 import org.junit.Test
@@ -36,16 +39,26 @@
 @RunWith(JUnit4::class)
 class ColorPickerInteractorTest {
     private lateinit var underTest: ColorPickerInteractor
+    private lateinit var repository: FakeColorPickerRepository
+    private lateinit var store: FakeSnapshotStore
 
     private lateinit var context: Context
 
     @Before
     fun setUp() {
         context = InstrumentationRegistry.getInstrumentation().targetContext
+        repository = FakeColorPickerRepository(context = context)
+        store = FakeSnapshotStore()
         underTest =
             ColorPickerInteractor(
-                repository = FakeColorPickerRepository(context = context),
+                repository = repository,
+                snapshotRestorer = {
+                    ColorPickerSnapshotRestorer(interactor = underTest).apply {
+                        runBlocking { setUpSnapshotRestorer(store = store) }
+                    }
+                },
             )
+        repository.setOptions(4, 4, ColorType.WALLPAPER_COLOR, 0)
     }
 
     @Test
@@ -66,4 +79,40 @@
         val presetColorOptionModelAfter = colorOptions()?.get(ColorType.BASIC_COLOR)?.get(1)
         assertThat(presetColorOptionModelAfter?.isSelected).isTrue()
     }
+
+    @Test
+    fun snapshotRestorer_updatesSnapshot() = runTest {
+        val colorOptions = collectLastValue(underTest.colorOptions)
+        val wallpaperColorOptionModel0 = colorOptions()?.get(ColorType.WALLPAPER_COLOR)?.get(0)
+        val wallpaperColorOptionModel1 = colorOptions()?.get(ColorType.WALLPAPER_COLOR)?.get(1)
+        assertThat(wallpaperColorOptionModel0?.isSelected).isTrue()
+        assertThat(wallpaperColorOptionModel1?.isSelected).isFalse()
+
+        val storedSnapshot = store.retrieve()
+        wallpaperColorOptionModel1?.let { underTest.select(it) }
+        val wallpaperColorOptionModel0After = colorOptions()?.get(ColorType.WALLPAPER_COLOR)?.get(0)
+        val wallpaperColorOptionModel1After = colorOptions()?.get(ColorType.WALLPAPER_COLOR)?.get(1)
+        assertThat(wallpaperColorOptionModel0After?.isSelected).isFalse()
+        assertThat(wallpaperColorOptionModel1After?.isSelected).isTrue()
+
+        assertThat(store.retrieve()).isNotEqualTo(storedSnapshot)
+    }
+
+    @Test
+    fun snapshotRestorer_doesNotUpdateSnapshotOnExternalUpdates() = runTest {
+        val colorOptions = collectLastValue(underTest.colorOptions)
+        val wallpaperColorOptionModel0 = colorOptions()?.get(ColorType.WALLPAPER_COLOR)?.get(0)
+        val wallpaperColorOptionModel1 = colorOptions()?.get(ColorType.WALLPAPER_COLOR)?.get(1)
+        assertThat(wallpaperColorOptionModel0?.isSelected).isTrue()
+        assertThat(wallpaperColorOptionModel1?.isSelected).isFalse()
+
+        val storedSnapshot = store.retrieve()
+        repository.setOptions(4, 4, ColorType.WALLPAPER_COLOR, 1)
+        val wallpaperColorOptionModel0After = colorOptions()?.get(ColorType.WALLPAPER_COLOR)?.get(0)
+        val wallpaperColorOptionModel1After = colorOptions()?.get(ColorType.WALLPAPER_COLOR)?.get(1)
+        assertThat(wallpaperColorOptionModel0After?.isSelected).isFalse()
+        assertThat(wallpaperColorOptionModel1After?.isSelected).isTrue()
+
+        assertThat(store.retrieve()).isEqualTo(storedSnapshot)
+    }
 }
diff --git a/tests/src/com/android/customization/model/picker/color/domain/interactor/ColorPickerSnapshotRestorerTest.kt b/tests/src/com/android/customization/model/picker/color/domain/interactor/ColorPickerSnapshotRestorerTest.kt
new file mode 100644
index 0000000..27b8550
--- /dev/null
+++ b/tests/src/com/android/customization/model/picker/color/domain/interactor/ColorPickerSnapshotRestorerTest.kt
@@ -0,0 +1,138 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package com.android.customization.model.picker.color.domain.interactor
+
+import android.content.Context
+import androidx.test.filters.SmallTest
+import androidx.test.platform.app.InstrumentationRegistry
+import com.android.customization.picker.color.data.repository.FakeColorPickerRepository
+import com.android.customization.picker.color.domain.interactor.ColorPickerInteractor
+import com.android.customization.picker.color.domain.interactor.ColorPickerSnapshotRestorer
+import com.android.customization.picker.color.shared.model.ColorOptionModel
+import com.android.customization.picker.color.shared.model.ColorType
+import com.android.wallpaper.testing.FakeSnapshotStore
+import com.android.wallpaper.testing.collectLastValue
+import com.google.common.truth.Truth
+import com.google.common.truth.Truth.assertThat
+import kotlinx.coroutines.ExperimentalCoroutinesApi
+import kotlinx.coroutines.test.runTest
+import org.junit.Before
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+
+@OptIn(ExperimentalCoroutinesApi::class)
+@SmallTest
+@RunWith(JUnit4::class)
+class ColorPickerSnapshotRestorerTest {
+
+    private lateinit var underTest: ColorPickerSnapshotRestorer
+    private lateinit var repository: FakeColorPickerRepository
+    private lateinit var store: FakeSnapshotStore
+
+    private lateinit var context: Context
+
+    @Before
+    fun setUp() {
+        context = InstrumentationRegistry.getInstrumentation().targetContext
+        repository = FakeColorPickerRepository(context = context)
+        underTest =
+            ColorPickerSnapshotRestorer(
+                interactor =
+                    ColorPickerInteractor(
+                        repository = repository,
+                        snapshotRestorer = { underTest },
+                    )
+            )
+        store = FakeSnapshotStore()
+    }
+
+    @Test
+    fun restoreToSnapshot_noCallsToStore_restoresToInitialSnapshot() = runTest {
+        val colorOptions = collectLastValue(repository.colorOptions)
+
+        repository.setOptions(4, 4, ColorType.WALLPAPER_COLOR, 2)
+        val initialSnapshot = underTest.setUpSnapshotRestorer(store = store)
+        assertThat(initialSnapshot.args).isNotEmpty()
+
+        val colorOptionToSelect = colorOptions()?.get(ColorType.BASIC_COLOR)?.get(3)
+        colorOptionToSelect?.let { repository.select(it) }
+        assertState(colorOptions(), ColorType.BASIC_COLOR, 3)
+
+        underTest.restoreToSnapshot(initialSnapshot)
+        assertState(colorOptions(), ColorType.WALLPAPER_COLOR, 2)
+    }
+
+    @Test
+    fun restoreToSnapshot_withCallToStore_restoresToInitialSnapshot() = runTest {
+        val colorOptions = collectLastValue(repository.colorOptions)
+
+        repository.setOptions(4, 4, ColorType.WALLPAPER_COLOR, 2)
+        val initialSnapshot = underTest.setUpSnapshotRestorer(store = store)
+        assertThat(initialSnapshot.args).isNotEmpty()
+
+        val colorOptionToSelect = colorOptions()?.get(ColorType.BASIC_COLOR)?.get(3)
+        colorOptionToSelect?.let { repository.select(it) }
+        assertState(colorOptions(), ColorType.BASIC_COLOR, 3)
+
+        val colorOptionToStore = colorOptions()?.get(ColorType.BASIC_COLOR)?.get(1)
+        colorOptionToStore?.let { underTest.storeSnapshot(colorOptionToStore) }
+
+        underTest.restoreToSnapshot(initialSnapshot)
+        assertState(colorOptions(), ColorType.WALLPAPER_COLOR, 2)
+    }
+
+    private fun assertState(
+        colorOptions: Map<ColorType, List<ColorOptionModel>>?,
+        selectedColorType: ColorType,
+        selectedColorIndex: Int
+    ) {
+        var foundSelectedColorOption = false
+        assertThat(colorOptions).isNotNull()
+        val optionsOfSelectedColorType = colorOptions?.get(selectedColorType)
+        assertThat(optionsOfSelectedColorType).isNotNull()
+        if (optionsOfSelectedColorType != null) {
+            for (i in optionsOfSelectedColorType.indices) {
+                val colorOptionHasSelectedIndex = i == selectedColorIndex
+                Truth.assertWithMessage(
+                        "Expected color option with index \"${i}\" to have" +
+                            " isSelected=$colorOptionHasSelectedIndex but it was" +
+                            " ${optionsOfSelectedColorType[i].isSelected}, num options: ${colorOptions.size}"
+                    )
+                    .that(optionsOfSelectedColorType[i].isSelected)
+                    .isEqualTo(colorOptionHasSelectedIndex)
+                foundSelectedColorOption = foundSelectedColorOption || colorOptionHasSelectedIndex
+            }
+            if (selectedColorIndex == -1) {
+                Truth.assertWithMessage(
+                        "Expected no color options to be selected, but a color option is" +
+                            " selected"
+                    )
+                    .that(foundSelectedColorOption)
+                    .isFalse()
+            } else {
+                Truth.assertWithMessage(
+                        "Expected a color option to be selected, but no color option is" +
+                            " selected"
+                    )
+                    .that(foundSelectedColorOption)
+                    .isTrue()
+            }
+        }
+    }
+}
diff --git a/tests/src/com/android/customization/model/picker/color/ui/viewmodel/ColorPickerViewModelTest.kt b/tests/src/com/android/customization/model/picker/color/ui/viewmodel/ColorPickerViewModelTest.kt
index 6e5f776..b7567ed 100644
--- a/tests/src/com/android/customization/model/picker/color/ui/viewmodel/ColorPickerViewModelTest.kt
+++ b/tests/src/com/android/customization/model/picker/color/ui/viewmodel/ColorPickerViewModelTest.kt
@@ -21,15 +21,24 @@
 import androidx.test.platform.app.InstrumentationRegistry
 import com.android.customization.picker.color.data.repository.FakeColorPickerRepository
 import com.android.customization.picker.color.domain.interactor.ColorPickerInteractor
+import com.android.customization.picker.color.domain.interactor.ColorPickerSnapshotRestorer
 import com.android.customization.picker.color.shared.model.ColorType
 import com.android.customization.picker.color.ui.viewmodel.ColorOptionViewModel
 import com.android.customization.picker.color.ui.viewmodel.ColorPickerViewModel
 import com.android.customization.picker.color.ui.viewmodel.ColorTypeViewModel
+import com.android.wallpaper.testing.FakeSnapshotStore
 import com.android.wallpaper.testing.collectLastValue
 import com.google.common.truth.Truth.assertThat
 import com.google.common.truth.Truth.assertWithMessage
+import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.ExperimentalCoroutinesApi
+import kotlinx.coroutines.runBlocking
+import kotlinx.coroutines.test.StandardTestDispatcher
+import kotlinx.coroutines.test.TestScope
+import kotlinx.coroutines.test.resetMain
 import kotlinx.coroutines.test.runTest
+import kotlinx.coroutines.test.setMain
+import org.junit.After
 import org.junit.Before
 import org.junit.Test
 import org.junit.runner.RunWith
@@ -40,80 +49,111 @@
 @RunWith(JUnit4::class)
 class ColorPickerViewModelTest {
     private lateinit var underTest: ColorPickerViewModel
+    private lateinit var repository: FakeColorPickerRepository
+    private lateinit var interactor: ColorPickerInteractor
+    private lateinit var store: FakeSnapshotStore
 
     private lateinit var context: Context
+    private lateinit var testScope: TestScope
 
     @Before
     fun setUp() {
         context = InstrumentationRegistry.getInstrumentation().targetContext
+        val testDispatcher = StandardTestDispatcher()
+        testScope = TestScope(testDispatcher)
+        Dispatchers.setMain(testDispatcher)
+        repository = FakeColorPickerRepository(context = context)
+        store = FakeSnapshotStore()
+
+        interactor =
+            ColorPickerInteractor(
+                repository = repository,
+                snapshotRestorer = {
+                    ColorPickerSnapshotRestorer(interactor = interactor).apply {
+                        runBlocking { setUpSnapshotRestorer(store = store) }
+                    }
+                },
+            )
 
         underTest =
-            ColorPickerViewModel.Factory(
-                    context = context,
-                    interactor =
-                        ColorPickerInteractor(
-                            repository = FakeColorPickerRepository(context = context),
-                        ),
-                )
+            ColorPickerViewModel.Factory(context = context, interactor = interactor)
                 .create(ColorPickerViewModel::class.java)
+
+        repository.setOptions(4, 4, ColorType.WALLPAPER_COLOR, 0)
+    }
+
+    @After
+    fun tearDown() {
+        Dispatchers.resetMain()
     }
 
     @Test
-    fun `Select a color section color`() = runTest {
-        val colorSectionOptions = collectLastValue(underTest.colorSectionOptions)
+    fun `Select a color section color`() =
+        testScope.runTest {
+            val colorSectionOptions = collectLastValue(underTest.colorSectionOptions)
 
-        assertColorOptionUiState(colorOptions = colorSectionOptions(), selectedColorOptionIndex = 0)
+            assertColorOptionUiState(
+                colorOptions = colorSectionOptions(),
+                selectedColorOptionIndex = 0
+            )
 
-        colorSectionOptions()?.get(2)?.onClick?.invoke()
-        assertColorOptionUiState(colorOptions = colorSectionOptions(), selectedColorOptionIndex = 2)
+            colorSectionOptions()?.get(2)?.onClick?.invoke()
+            assertColorOptionUiState(
+                colorOptions = colorSectionOptions(),
+                selectedColorOptionIndex = 2
+            )
 
-        colorSectionOptions()?.get(4)?.onClick?.invoke()
-        assertColorOptionUiState(colorOptions = colorSectionOptions(), selectedColorOptionIndex = 4)
-    }
+            colorSectionOptions()?.get(4)?.onClick?.invoke()
+            assertColorOptionUiState(
+                colorOptions = colorSectionOptions(),
+                selectedColorOptionIndex = 4
+            )
+        }
 
     @Test
-    fun `Select a preset color`() = runTest {
-        val colorTypes = collectLastValue(underTest.colorTypes)
-        val colorOptions = collectLastValue(underTest.colorOptions)
+    fun `Select a preset color`() =
+        testScope.runTest {
+            val colorTypes = collectLastValue(underTest.colorTypes)
+            val colorOptions = collectLastValue(underTest.colorOptions)
 
-        // Initially, the wallpaper color tab should be selected
-        assertPickerUiState(
-            colorTypes = colorTypes(),
-            colorOptions = colorOptions(),
-            selectedColorTypeText = "Wallpaper colors",
-            selectedColorOptionIndex = 0
-        )
+            // Initially, the wallpaper color tab should be selected
+            assertPickerUiState(
+                colorTypes = colorTypes(),
+                colorOptions = colorOptions(),
+                selectedColorTypeText = "Wallpaper colors",
+                selectedColorOptionIndex = 0
+            )
 
-        // Select "Basic colors" tab
-        colorTypes()?.get(ColorType.BASIC_COLOR)?.onClick?.invoke()
-        assertPickerUiState(
-            colorTypes = colorTypes(),
-            colorOptions = colorOptions(),
-            selectedColorTypeText = "Basic colors",
-            selectedColorOptionIndex = -1
-        )
+            // Select "Basic colors" tab
+            colorTypes()?.get(ColorType.BASIC_COLOR)?.onClick?.invoke()
+            assertPickerUiState(
+                colorTypes = colorTypes(),
+                colorOptions = colorOptions(),
+                selectedColorTypeText = "Basic colors",
+                selectedColorOptionIndex = -1
+            )
 
-        // Select a color option
-        colorOptions()?.get(2)?.onClick?.invoke()
+            // Select a color option
+            colorOptions()?.get(2)?.onClick?.invoke()
 
-        // Check original option is no longer selected
-        colorTypes()?.get(ColorType.WALLPAPER_COLOR)?.onClick?.invoke()
-        assertPickerUiState(
-            colorTypes = colorTypes(),
-            colorOptions = colorOptions(),
-            selectedColorTypeText = "Wallpaper colors",
-            selectedColorOptionIndex = -1
-        )
+            // Check original option is no longer selected
+            colorTypes()?.get(ColorType.WALLPAPER_COLOR)?.onClick?.invoke()
+            assertPickerUiState(
+                colorTypes = colorTypes(),
+                colorOptions = colorOptions(),
+                selectedColorTypeText = "Wallpaper colors",
+                selectedColorOptionIndex = -1
+            )
 
-        // Check new option is selected
-        colorTypes()?.get(ColorType.BASIC_COLOR)?.onClick?.invoke()
-        assertPickerUiState(
-            colorTypes = colorTypes(),
-            colorOptions = colorOptions(),
-            selectedColorTypeText = "Basic colors",
-            selectedColorOptionIndex = 2
-        )
-    }
+            // Check new option is selected
+            colorTypes()?.get(ColorType.BASIC_COLOR)?.onClick?.invoke()
+            assertPickerUiState(
+                colorTypes = colorTypes(),
+                colorOptions = colorOptions(),
+                selectedColorTypeText = "Basic colors",
+                selectedColorOptionIndex = 2
+            )
+        }
 
     /**
      * Asserts the entire picker UI state is what is expected. This includes the color type tabs and
diff --git a/tests/src/com/android/customization/testing/TestCustomizationInjector.kt b/tests/src/com/android/customization/testing/TestCustomizationInjector.kt
index 3ab7c84..2a2ab5e 100644
--- a/tests/src/com/android/customization/testing/TestCustomizationInjector.kt
+++ b/tests/src/com/android/customization/testing/TestCustomizationInjector.kt
@@ -18,6 +18,7 @@
 import com.android.customization.picker.clock.ui.viewmodel.ClockSettingsViewModel
 import com.android.customization.picker.color.data.repository.ColorPickerRepositoryImpl
 import com.android.customization.picker.color.domain.interactor.ColorPickerInteractor
+import com.android.customization.picker.color.domain.interactor.ColorPickerSnapshotRestorer
 import com.android.customization.picker.color.ui.viewmodel.ColorPickerViewModel
 import com.android.customization.picker.quickaffordance.data.repository.KeyguardQuickAffordancePickerRepository
 import com.android.customization.picker.quickaffordance.domain.interactor.KeyguardQuickAffordancePickerInteractor
@@ -54,6 +55,7 @@
     private var clockViewFactory: ClockViewFactory? = null
     private var colorPickerInteractor: ColorPickerInteractor? = null
     private var colorPickerViewModelFactory: ColorPickerViewModel.Factory? = null
+    private var colorPickerSnapshotRestorer: ColorPickerSnapshotRestorer? = null
     private var clockCarouselViewModel: ClockCarouselViewModel? = null
     private var clockSettingsViewModelFactory: ClockSettingsViewModel.Factory? = null
 
@@ -118,6 +120,8 @@
         val restorers: MutableMap<Int, SnapshotRestorer> = HashMap()
         restorers[KEY_QUICK_AFFORDANCE_SNAPSHOT_RESTORER] =
             getKeyguardQuickAffordanceSnapshotRestorer(context)
+        restorers[KEY_COLOR_PICKER_SNAPSHOT_RESTORER] =
+            getColorPickerSnapshotRestorer(context, getWallpaperColorsViewModel())
         return restorers
     }
 
@@ -168,7 +172,12 @@
         wallpaperColorsViewModel: WallpaperColorsViewModel,
     ): ColorPickerInteractor {
         return colorPickerInteractor
-            ?: ColorPickerInteractor(ColorPickerRepositoryImpl(context, wallpaperColorsViewModel))
+            ?: ColorPickerInteractor(
+                    repository = ColorPickerRepositoryImpl(context, wallpaperColorsViewModel),
+                    snapshotRestorer = {
+                        getColorPickerSnapshotRestorer(context, wallpaperColorsViewModel)
+                    },
+                )
                 .also { colorPickerInteractor = it }
     }
 
@@ -184,6 +193,17 @@
                 .also { colorPickerViewModelFactory = it }
     }
 
+    private fun getColorPickerSnapshotRestorer(
+        context: Context,
+        wallpaperColorsViewModel: WallpaperColorsViewModel
+    ): ColorPickerSnapshotRestorer {
+        return colorPickerSnapshotRestorer
+            ?: ColorPickerSnapshotRestorer(
+                    getColorPickerInteractor(context, wallpaperColorsViewModel)
+                )
+                .also { colorPickerSnapshotRestorer = it }
+    }
+
     override fun getClockCarouselViewModel(context: Context): ClockCarouselViewModel {
         return clockCarouselViewModel
             ?: ClockCarouselViewModel(getClockPickerInteractor(context)).also {
@@ -209,5 +229,7 @@
 
     companion object {
         private const val KEY_QUICK_AFFORDANCE_SNAPSHOT_RESTORER = 1
+        private const val KEY_COLOR_PICKER_SNAPSHOT_RESTORER =
+            KEY_QUICK_AFFORDANCE_SNAPSHOT_RESTORER + 1
     }
 }