diff --git a/services/core/java/com/android/server/pm/ApexManager.java b/services/core/java/com/android/server/pm/ApexManager.java
index 307a07b..a009183 100644
--- a/services/core/java/com/android/server/pm/ApexManager.java
+++ b/services/core/java/com/android/server/pm/ApexManager.java
@@ -32,10 +32,13 @@
 import android.content.pm.PackageInstaller;
 import android.content.pm.PackageManager;
 import android.content.pm.PackageParser;
+import android.content.pm.parsing.AndroidPackage;
 import android.os.Environment;
 import android.os.RemoteException;
 import android.os.ServiceManager;
 import android.sysprop.ApexProperties;
+import android.util.ArrayMap;
+import android.util.ArraySet;
 import android.util.Singleton;
 import android.util.Slog;
 
@@ -44,15 +47,20 @@
 import com.android.internal.os.BackgroundThread;
 import com.android.internal.util.IndentingPrintWriter;
 
+import com.google.android.collect.Lists;
+
 import java.io.File;
 import java.io.PrintWriter;
 import java.lang.annotation.Retention;
 import java.lang.annotation.RetentionPolicy;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.HashSet;
+import java.util.Iterator;
 import java.util.List;
+import java.util.Map;
+import java.util.Set;
 import java.util.stream.Collectors;
 
 /**
@@ -97,12 +105,27 @@
      * Minimal information about APEX mount points and the original APEX package they refer to.
      */
     static class ActiveApexInfo {
+        @Nullable public final String apexModuleName;  // null via the two-File constructor
         public final File apexDirectory;
-        public final File preinstalledApexPath;
+        public final File preInstalledApexPath;
 
-        private ActiveApexInfo(File apexDirectory, File preinstalledApexPath) {
+        private ActiveApexInfo(File apexDirectory, File preInstalledApexPath) {
+            this(null, apexDirectory, preInstalledApexPath);  // no module name supplied
+        }
+
+        private ActiveApexInfo(@Nullable String apexModuleName, File apexDirectory,
+                File preInstalledApexPath) {
+            this.apexModuleName = apexModuleName;
             this.apexDirectory = apexDirectory;
-            this.preinstalledApexPath = preinstalledApexPath;
+            this.preInstalledApexPath = preInstalledApexPath;
+        }
+
+        private ActiveApexInfo(ApexInfo apexInfo) {
+            this(
+                    apexInfo.moduleName,
+                    new File(Environment.getApexDirectory() + File.separator
+                            + apexInfo.moduleName),  // apexDirectory = <apex dir>/<moduleName>
+                    new File(apexInfo.preinstalledModulePath));  // preInstalledApexPath
         }
     }
 
@@ -232,6 +255,17 @@
     abstract boolean uninstallApex(String apexPackagePath);
 
     /**
+     * Registers {@code pkg} as an apk that is embedded inside an apex.
+     */
+    abstract void registerApkInApex(AndroidPackage pkg);
+
+    /**
+     * Returns the {@code packageName}s of the apks registered as being inside the given apex.
+     * @param apexPackageName Package name of the apk container of the apex
+     */
+    abstract List<String> getApksInApex(String apexPackageName);
+
+    /**
      * Dumps various state information to the provided {@link PrintWriter} object.
      *
      * @param pw the {@link PrintWriter} object to send information to.
@@ -255,16 +289,33 @@
     static class ApexManagerImpl extends ApexManager {
         private final IApexService mApexService;
         private final Object mLock = new Object();
+
+        @GuardedBy("mLock")
+        private Set<ActiveApexInfo> mActiveApexInfosCache;
+
         /**
-         * A map from {@code APEX packageName} to the {@Link PackageInfo} generated from the {@code
-         * AndroidManifest.xml}
-         *
-         * <p>Note that key of this map is {@code packageName} field of the corresponding {@code
-         * AndroidManifest.xml}.
-          */
+         * Contains the list of {@code packageName}s of apks-in-apex for given
+         * {@code apexModuleName}. See {@link #mPackageNameToApexModuleName} to understand the
+         * difference between {@code packageName} and {@code apexModuleName}.
+         */
+        @GuardedBy("mLock")
+        private Map<String, List<String>> mApksInApex = new ArrayMap<>();
+
         @GuardedBy("mLock")
         private List<PackageInfo> mAllPackagesCache;
 
+        /**
+         * An APEX is a file format that delivers the apex-payload wrapped in an apk container. The
+         * apk container has a reference name, called {@code packageName}, which is found inside the
+         * {@code AndroidManifest.xml}. The apex payload inside the container also has a reference
+         * name, called {@code apexModuleName}, which is found in {@code apex_manifest.json} file.
+         *
+         * {@link #mPackageNameToApexModuleName} contains the mapping from {@code packageName} of
+         * the apk container to {@code apexModuleName} of the apex-payload inside.
+         */
+        @GuardedBy("mLock")
+        private Map<String, String> mPackageNameToApexModuleName;
+
         ApexManagerImpl(IApexService apexService) {
             mApexService = apexService;
         }
@@ -291,18 +342,25 @@
 
         @Override
         List<ActiveApexInfo> getActiveApexInfos() {
-            try {
-                return Arrays.stream(mApexService.getActivePackages())
-                        .map(apexInfo -> new ActiveApexInfo(
-                                new File(
-                                Environment.getApexDirectory() + File.separator
-                                        + apexInfo.moduleName),
-                                new File(apexInfo.preinstalledModulePath))).collect(
-                                Collectors.toList());
-            } catch (RemoteException e) {
-                Slog.e(TAG, "Unable to retrieve packages from apexservice", e);
+            synchronized (mLock) {
+                if (mActiveApexInfosCache == null) {
+                    // Fill a local set and publish it only once apexservice has answered.
+                    // Assigning the still-empty set to the cache before the call would turn a
+                    // single failed call into a permanently cached empty result.
+                    final Set<ActiveApexInfo> activeInfos = new ArraySet<>();
+                    try {
+                        for (ApexInfo apexInfo : mApexService.getActivePackages()) {
+                            activeInfos.add(new ActiveApexInfo(apexInfo));
+                        }
+                    } catch (RemoteException e) {
+                        Slog.e(TAG, "Unable to retrieve packages from apexservice", e);
+                        return Collections.emptyList();
+                    }
+                    mActiveApexInfosCache = activeInfos;
+                }
+                // Hand out a copy so that callers can never mutate the cached set.
+                return new ArrayList<>(mActiveApexInfosCache);
             }
-            return Collections.emptyList();
         }
 
         @Override
@@ -325,6 +383,7 @@
                 }
                 try {
                     mAllPackagesCache = new ArrayList<>();
+                    mPackageNameToApexModuleName = new HashMap<>();
                     HashSet<String> activePackagesSet = new HashSet<>();
                     HashSet<String> factoryPackagesSet = new HashSet<>();
                     final ApexInfo[] allPkgs = mApexService.getAllPackages();
@@ -350,6 +409,7 @@
                         final PackageInfo packageInfo =
                                 PackageParser.generatePackageInfo(pkg, ai, flags);
                         mAllPackagesCache.add(packageInfo);
+                        mPackageNameToApexModuleName.put(packageInfo.packageName, ai.moduleName);
                         if (ai.isActive) {
                             if (activePackagesSet.contains(packageInfo.packageName)) {
                                 throw new IllegalStateException(
@@ -366,7 +426,6 @@
                             }
                             factoryPackagesSet.add(packageInfo.packageName);
                         }
-
                     }
                 } catch (RemoteException re) {
                     Slog.e(TAG, "Unable to retrieve packages from apexservice: " + re.toString());
@@ -533,6 +592,37 @@
             }
         }
 
+        @Override
+        void registerApkInApex(AndroidPackage pkg) {
+            // An apk belongs to an apex only when it lives strictly below that apex's
+            // directory; the trailing separator keeps /apex/foo from matching /apex/foobar.
+            final String apkPath = pkg.getBaseCodePath();
+            synchronized (mLock) {
+                // getActiveApexInfos() fills the cache on demand and never returns null, so a
+                // not-yet-populated cache cannot cause a NullPointerException here.
+                for (ActiveApexInfo aai : getActiveApexInfos()) {
+                    final String apexRoot = aai.apexDirectory.getAbsolutePath() + File.separator;
+                    if (apkPath.startsWith(apexRoot)) {
+                        mApksInApex.computeIfAbsent(aai.apexModuleName, k -> new ArrayList<>())
+                                .add(pkg.getPackageName());
+                    }
+                }
+            }
+        }
+
+        @Override
+        List<String> getApksInApex(String apexPackageName) {
+            // TODO(b/142712057): Avoid calling populateAllPackagesCacheIfNeeded during boot.
+            populateAllPackagesCacheIfNeeded();
+            synchronized (mLock) {
+                final String moduleName = mPackageNameToApexModuleName.get(apexPackageName);
+                final List<String> apks = moduleName == null ? null : mApksInApex.get(moduleName);
+                // Return a snapshot: the stored list is guarded by mLock and is appended to by
+                // registerApkInApex(), so it must not escape this synchronized block.
+                return apks == null ? Collections.emptyList() : new ArrayList<>(apks);
+            }
+        }
+
         /**
          * Dump information about the packages contained in a particular cache
          * @param packagesCache the cache to print information about.
@@ -614,7 +704,6 @@
      * updating APEX packages.
      */
     private static final class ApexManagerFlattenedApex extends ApexManager {
-
         @Override
         List<ActiveApexInfo> getActiveApexInfos() {
             // There is no apexd running in case of flattened apex
@@ -721,6 +810,16 @@
         }
 
         @Override
+        void registerApkInApex(AndroidPackage pkg) {
+            // No-op: this flattened-APEX implementation keeps no apk-in-apex registry.
+        }
+
+        @Override
+        List<String> getApksInApex(String apexPackageName) {
+            return Collections.emptyList();  // registerApkInApex() is a no-op in this class
+        }
+
+        @Override
         void dump(PrintWriter pw, String packageName) {
             // No-op
         }
diff --git a/services/core/java/com/android/server/pm/PackageInstallerService.java b/services/core/java/com/android/server/pm/PackageInstallerService.java
index c43f234..6331dd4 100644
--- a/services/core/java/com/android/server/pm/PackageInstallerService.java
+++ b/services/core/java/com/android/server/pm/PackageInstallerService.java
@@ -188,7 +188,7 @@
         }
     };
 
-    public PackageInstallerService(Context context, PackageManagerService pm, ApexManager am) {
+    public PackageInstallerService(Context context, PackageManagerService pm) {
         mContext = context;
         mPm = pm;
         mPermissionManager = LocalServices.getService(PermissionManagerServiceInternal.class);
@@ -206,9 +206,8 @@
         mSessionsDir = new File(Environment.getDataSystemDirectory(), "install_sessions");
         mSessionsDir.mkdirs();
 
-        mApexManager = am;
-
-        mStagingManager = new StagingManager(this, am, context);
+        mApexManager = ApexManager.getInstance();
+        mStagingManager = new StagingManager(this, context);
     }
 
     boolean okToSendBroadcasts()  {
diff --git a/services/core/java/com/android/server/pm/PackageManagerService.java b/services/core/java/com/android/server/pm/PackageManagerService.java
index a9571d9..52e8e37 100644
--- a/services/core/java/com/android/server/pm/PackageManagerService.java
+++ b/services/core/java/com/android/server/pm/PackageManagerService.java
@@ -492,6 +492,7 @@
     static final int SCAN_AS_PRODUCT = 1 << 20;
     static final int SCAN_AS_SYSTEM_EXT = 1 << 21;
     static final int SCAN_AS_ODM = 1 << 22;
+    static final int SCAN_AS_APK_IN_APEX = 1 << 23;
 
     @IntDef(flag = true, prefix = { "SCAN_" }, value = {
             SCAN_NO_DEX,
@@ -2589,6 +2590,9 @@
                     & (SCAN_AS_VENDOR | SCAN_AS_ODM | SCAN_AS_PRODUCT | SCAN_AS_SYSTEM_EXT)) != 0) {
                 return true;
             }
+            if ((scanFlags & SCAN_AS_APK_IN_APEX) != 0) {
+                return true;
+            }
             return false;
         }
 
@@ -3332,7 +3336,7 @@
                 }
             }
 
-            mInstallerService = new PackageInstallerService(mContext, this, mApexManager);
+            mInstallerService = new PackageInstallerService(mContext, this);
             final Pair<ComponentName, String> instantAppResolverComponent =
                     getInstantAppResolverLPr();
             if (instantAppResolverComponent != null) {
@@ -11710,6 +11714,9 @@
             mSettings.insertPackageSettingLPw(pkgSetting, pkg);
             // Add the new setting to mPackages
             mPackages.put(pkg.getAppInfoPackageName(), pkg);
+            if ((scanFlags & SCAN_AS_APK_IN_APEX) != 0) {
+                mApexManager.registerApkInApex(pkg);
+            }
 
             // Add the package's KeySets to the global KeySetManagerService
             KeySetManagerService ksms = mSettings.mKeySetManagerService;
@@ -17757,10 +17764,10 @@
             ApexManager.ActiveApexInfo apexInfo) {
         for (int i = 0, size = SYSTEM_PARTITIONS.size(); i < size; i++) {
             SystemPartition sp = SYSTEM_PARTITIONS.get(i);
-            if (apexInfo.preinstalledApexPath.getAbsolutePath().startsWith(
+            if (apexInfo.preInstalledApexPath.getAbsolutePath().startsWith(
                     sp.folder.getAbsolutePath())) {
-                return new SystemPartition(apexInfo.apexDirectory, sp.scanFlag,
-                        false /* hasOverlays */);
+                return new SystemPartition(apexInfo.apexDirectory,
+                        sp.scanFlag | SCAN_AS_APK_IN_APEX, false /* hasOverlays */);
             }
         }
         return null;
@@ -23452,6 +23459,11 @@
         }
 
         @Override
+        public List<String> getApksInApex(String apexPackageName) {
+            return PackageManagerService.this.mApexManager.getApksInApex(apexPackageName);
+        }
+
+        @Override
         public void uninstallApex(String packageName, long versionCode, int userId,
                 IntentSender intentSender, int flags) {
             final int callerUid = Binder.getCallingUid();
diff --git a/services/core/java/com/android/server/pm/StagingManager.java b/services/core/java/com/android/server/pm/StagingManager.java
index 9e462cd..2265d01 100644
--- a/services/core/java/com/android/server/pm/StagingManager.java
+++ b/services/core/java/com/android/server/pm/StagingManager.java
@@ -33,11 +33,13 @@
 import android.content.pm.PackageInstaller;
 import android.content.pm.PackageInstaller.SessionInfo;
 import android.content.pm.PackageManager;
+import android.content.pm.PackageManagerInternal;
 import android.content.pm.PackageParser;
 import android.content.pm.PackageParser.PackageParserException;
 import android.content.pm.PackageParser.SigningDetails;
 import android.content.pm.PackageParser.SigningDetails.SignatureSchemeVersion;
 import android.content.pm.ParceledListSlice;
+import android.content.pm.parsing.AndroidPackage;
 import android.content.rollback.IRollbackManager;
 import android.content.rollback.RollbackInfo;
 import android.content.rollback.RollbackManager;
@@ -50,6 +52,8 @@
 import android.os.PowerManager;
 import android.os.RemoteException;
 import android.os.ServiceManager;
+import android.os.UserHandle;
+import android.os.UserManagerInternal;
 import android.os.storage.IStorageManager;
 import android.os.storage.StorageManager;
 import android.util.IntArray;
@@ -61,6 +65,7 @@
 import com.android.internal.annotations.GuardedBy;
 import com.android.internal.content.PackageHelper;
 import com.android.internal.os.BackgroundThread;
+import com.android.server.LocalServices;
 
 import java.io.File;
 import java.io.IOException;
@@ -93,10 +98,11 @@
     @GuardedBy("mStagedSessions")
     private final SparseIntArray mSessionRollbackIds = new SparseIntArray();
 
-    StagingManager(PackageInstallerService pi, ApexManager am, Context context) {
+    StagingManager(PackageInstallerService pi, Context context) {
         mPi = pi;
-        mApexManager = am;
         mContext = context;
+
+        mApexManager = ApexManager.getInstance();
         mPowerManager = (PowerManager) context.getSystemService(Context.POWER_SERVICE);
         mPreRebootVerificationHandler = new PreRebootVerificationHandler(
                 BackgroundThread.get().getLooper());
@@ -334,6 +340,88 @@
         return PackageHelper.getStorageManager().needsCheckpoint();
     }
 
+    /**
+     * Snapshots and restores user data of the apks shipped inside the apex(es) of a session.
+     * Those apks never go through the apk-install flow: PackageManager scans them directly from
+     * the system directory, so RollbackManager has to be told about their data from here.
+     */
+    private void snapshotAndRestoreApkInApexUserData(PackageInstallerSession session) {
+        // Nothing to do unless the session carries an apex.
+        if (!sessionContainsApex(session)) {
+            return;
+        }
+
+        final boolean rollbackEnabled =
+                (session.params.installFlags & PackageManager.INSTALL_ENABLE_ROLLBACK) != 0;
+        final boolean isRollback =
+                session.params.installReason == PackageManager.INSTALL_REASON_ROLLBACK;
+        if (!rollbackEnabled && !isRollback) {
+            return;
+        }
+
+        // Gather the apex sessions whose embedded apks must be processed.
+        final List<PackageInstallerSession> apexSessions = new ArrayList<>();
+        if (!session.isMultiPackage()) {
+            apexSessions.add(session);
+        } else {
+            final List<PackageInstallerSession> children = new ArrayList<>();
+            synchronized (mStagedSessions) {
+                for (int childId : session.getChildSessionIds()) {
+                    final PackageInstallerSession child = mStagedSessions.get(childId);
+                    if (child != null) {
+                        children.add(child);
+                    }
+                }
+            }
+            for (PackageInstallerSession child : children) {
+                if (sessionContainsApex(child)) {
+                    apexSessions.add(child);
+                }
+            }
+        }
+
+        // Handle every apk found inside each collected apex.
+        for (PackageInstallerSession apexSession : apexSessions) {
+            for (String apk : mApexManager.getApksInApex(apexSession.getPackageName())) {
+                snapshotAndRestoreApkInApexUserData(apk);
+            }
+        }
+    }
+
+    /** Snapshots and restores user data of one apk-in-apex for all users it is installed for. */
+    private void snapshotAndRestoreApkInApexUserData(String packageName) {
+        IRollbackManager rm = IRollbackManager.Stub.asInterface(
+                    ServiceManager.getService(Context.ROLLBACK_SERVICE));
+
+        PackageManagerInternal pmi = LocalServices.getService(PackageManagerInternal.class);
+        AndroidPackage pkg = pmi.getPackage(packageName);
+        if (pkg == null) {
+            Slog.e(TAG, "Could not find package: " + packageName
+                    + " for snapshotting/restoring user data.");
+            return;
+        }
+        final String seInfo = pkg.getSeInfo();
+        final UserManagerInternal um = LocalServices.getService(UserManagerInternal.class);
+        final int[] allUsers = um.getUserIds();
+
+        final PackageSetting ps = (PackageSetting) pmi.getPackageSetting(packageName);
+        if (ps != null && rm != null) {
+            final int appId = ps.appId;
+            final long ceDataInode = ps.getCeDataInode(UserHandle.USER_SYSTEM);
+            // NOTE: We ignore the user specified in the InstallParam because we know this is
+            // an update, and hence need to restore data for all installed users.
+            final int[] installedUsers = ps.queryInstalledUsers(allUsers, true);
+
+            try {
+                // Token 0: not part of an install flow, so PackageManager needs no callback.
+                rm.snapshotAndRestoreUserData(packageName, installedUsers, appId, ceDataInode,
+                        seInfo, 0 /*token*/);
+            } catch (RemoteException re) {
+                Slog.e(TAG, "Error snapshotting/restoring user data: " + re);
+            }
+        }
+    }
+
     private void resumeSession(@NonNull PackageInstallerSession session) {
         Slog.d(TAG, "Resuming session " + session.sessionId);
 
@@ -407,6 +495,7 @@
                 abortCheckpoint();
                 return;
             }
+            snapshotAndRestoreApkInApexUserData(session);
             Slog.i(TAG, "APEX packages in session " + session.sessionId
                     + " were successfully activated. Proceeding with APK packages, if any");
         }
@@ -529,7 +618,7 @@
                         Arrays.stream(session.getChildSessionIds())
                                 // Retrieve cached sessions matching ids.
                                 .mapToObj(i -> mStagedSessions.get(i))
-                                // Filter only the ones containing APKs.s
+                                // Filter only the ones containing APKs.
                                 .filter(childSession -> !isApexSession(childSession))
                                 .collect(Collectors.toList());
             }
diff --git a/services/core/java/com/android/server/rollback/Rollback.java b/services/core/java/com/android/server/rollback/Rollback.java
index 88c1564..d719afb 100644
--- a/services/core/java/com/android/server/rollback/Rollback.java
+++ b/services/core/java/com/android/server/rollback/Rollback.java
@@ -323,8 +323,8 @@
                 new VersionedPackage(packageName, newVersion),
                 new VersionedPackage(packageName, installedVersion),
                 new IntArray() /* pendingBackups */, new ArrayList<>() /* pendingRestores */,
-                isApex, new IntArray(), new SparseLongArray() /* ceSnapshotInodes */,
-                rollbackDataPolicy);
+                isApex, false /* isApkInApex */, new IntArray(),
+                new SparseLongArray() /* ceSnapshotInodes */, rollbackDataPolicy);
 
         synchronized (mLock) {
             info.getPackages().add(packageRollbackInfo);
@@ -334,6 +334,30 @@
     }
 
     /**
+     * Adds the given apk-in-apex to the packages covered by this rollback.
+     *
+     * @return always {@code true}: no step in this method reports a failure.
+     */
+    boolean enableForPackageInApex(String packageName, long installedVersion,
+            int rollbackDataPolicy) {
+        // TODO(b/142712057): Extract the new version number of apk-in-apex.
+        // The new version for the apk-in-apex is set to 0 for now. If the package is then further
+        // updated via non-staged install flow, then RollbackManagerServiceImpl#onPackageReplaced()
+        // will be called and this rollback will be deleted. Other ways of package update have not
+        // been handled yet.
+        PackageRollbackInfo packageRollbackInfo = new PackageRollbackInfo(
+                new VersionedPackage(packageName, 0 /* newVersion */),
+                new VersionedPackage(packageName, installedVersion),
+                new IntArray() /* pendingBackups */, new ArrayList<>() /* pendingRestores */,
+                false /* isApex */, true /* isApkInApex */, new IntArray(),
+                new SparseLongArray() /* ceSnapshotInodes */, rollbackDataPolicy);
+        synchronized (mLock) {
+            info.getPackages().add(packageRollbackInfo);
+        }
+        return true;
+    }
+
+    /**
      * Snapshots user data for the provided package and user ids. Does nothing if this rollback is
      * not in the ENABLING state.
      */
@@ -428,6 +452,11 @@
                         parentSessionId);
 
                 for (PackageRollbackInfo pkgRollbackInfo : info.getPackages()) {
+                    if (pkgRollbackInfo.isApkInApex()) {
+                        // No need to issue a downgrade install request for apk-in-apex. It will
+                        // be rolled back when its parent apex is downgraded.
+                        continue;
+                    }
                     PackageInstaller.SessionParams params = new PackageInstaller.SessionParams(
                             PackageInstaller.SessionParams.MODE_FULL_INSTALL);
                     String installerPackageName = mInstallerPackageName;
@@ -453,7 +482,8 @@
                             this, pkgRollbackInfo.getPackageName());
                     if (packageCodePaths == null) {
                         sendFailure(context, statusReceiver, RollbackManager.STATUS_FAILURE,
-                                "Backup copy of package inaccessible");
+                                "Backup copy of package: "
+                                        + pkgRollbackInfo.getPackageName() + " is inaccessible");
                         return;
                     }
 
diff --git a/services/core/java/com/android/server/rollback/RollbackManagerServiceImpl.java b/services/core/java/com/android/server/rollback/RollbackManagerServiceImpl.java
index e29d1a7..2a59af3 100644
--- a/services/core/java/com/android/server/rollback/RollbackManagerServiceImpl.java
+++ b/services/core/java/com/android/server/rollback/RollbackManagerServiceImpl.java
@@ -891,9 +891,36 @@
         }
 
         ApplicationInfo appInfo = pkgInfo.applicationInfo;
-        return rollback.enableForPackage(packageName, newPackage.versionCode,
+        boolean success = rollback.enableForPackage(packageName, newPackage.versionCode,
                 pkgInfo.getLongVersionCode(), isApex, appInfo.sourceDir,
                 appInfo.splitSourceDirs, session.rollbackDataPolicy);
+        if (!success) {
+            return success;
+        }
+
+        if (isApex) {
+            // Check if this apex contains apks inside it. If true, then they should be added as
+            // a RollbackPackageInfo into this rollback
+            final PackageManagerInternal pmi = LocalServices.getService(
+                    PackageManagerInternal.class);
+            List<String> apksInApex = pmi.getApksInApex(packageName);
+            for (String apkInApex : apksInApex) {
+                // Get information about the currently installed package.
+                final PackageInfo apkPkgInfo;
+                try {
+                    apkPkgInfo = getPackageInfo(apkInApex);
+                } catch (PackageManager.NameNotFoundException e) {
+                    // TODO: Support rolling back fresh package installs rather than
+                    // fail here. Test this case.
+                    Slog.e(TAG, apkInApex + " is not installed");
+                    return false;
+                }
+                success = rollback.enableForPackageInApex(
+                        apkInApex, apkPkgInfo.getLongVersionCode(), session.rollbackDataPolicy);
+                if (!success) return success;
+            }
+        }
+        return true;
     }
 
     @Override
@@ -907,9 +934,13 @@
         getHandler().post(() -> {
             snapshotUserDataInternal(packageName, userIds);
             restoreUserDataInternal(packageName, userIds, appId, seInfo);
-            final PackageManagerInternal pmi = LocalServices.getService(
-                    PackageManagerInternal.class);
-            pmi.finishPackageInstall(token, false);
+            // When this method is called as part of the install flow, a positive token number is
+            // passed to it. Need to notify the PackageManager when we are done.
+            if (token > 0) {
+                final PackageManagerInternal pmi = LocalServices.getService(
+                        PackageManagerInternal.class);
+                pmi.finishPackageInstall(token, false);
+            }
         });
     }
 
@@ -1195,11 +1226,12 @@
             return null;
         }
 
-        if (rollback.getPackageCount() != newRollback.getPackageSessionIdCount()) {
-            Slog.e(TAG, "Failed to enable rollback for all packages in session.");
-            rollback.delete(mAppDataRollbackHelper);
-            return null;
-        }
+        // TODO(b/142712057): Re-enable this check. For that we need count of apks-in-apex
+//        if (rollback.getPackageCount() != newRollback.getPackageSessionIdCount()) {
+//            Slog.e(TAG, "Failed to enable rollback for all packages in session.");
+//            rollback.delete(mAppDataRollbackHelper);
+//            return null;
+//        }
 
         rollback.saveRollback();
         synchronized (mLock) {
diff --git a/services/core/java/com/android/server/rollback/RollbackStore.java b/services/core/java/com/android/server/rollback/RollbackStore.java
index df75a29..bbcd0de 100644
--- a/services/core/java/com/android/server/rollback/RollbackStore.java
+++ b/services/core/java/com/android/server/rollback/RollbackStore.java
@@ -341,6 +341,7 @@
         json.put("pendingRestores", convertToJsonArray(pendingRestores));
 
         json.put("isApex", info.isApex());
+        json.put("isApkInApex", info.isApkInApex());
 
         // Field is named 'installedUsers' for legacy reasons.
         json.put("installedUsers", convertToJsonArray(snapshottedUsers));
@@ -364,6 +365,7 @@
                 json.getJSONArray("pendingRestores"));
 
         final boolean isApex = json.getBoolean("isApex");
+        final boolean isApkInApex = json.getBoolean("isApkInApex");
 
         // Field is named 'installedUsers' for legacy reasons.
         final IntArray snapshottedUsers = convertToIntArray(json.getJSONArray("installedUsers"));
@@ -375,8 +377,8 @@
                 PackageManager.RollbackDataPolicy.RESTORE);
 
         return new PackageRollbackInfo(versionRolledBackFrom, versionRolledBackTo,
-                pendingBackups, pendingRestores, isApex, snapshottedUsers, ceSnapshotInodes,
-                rollbackDataPolicy);
+                pendingBackups, pendingRestores, isApex, isApkInApex, snapshottedUsers,
+                ceSnapshotInodes, rollbackDataPolicy);
     }
 
     private static JSONArray versionedPackagesToJson(List<VersionedPackage> packages)
