vfio-pci: Use mutex around open, release, and remove

Serializing open and release with a mutex allows us to fix a refcnt
error: when enabling the device failed in open, the reference count
was left incremented.  It also lets us prevent devices from being
unbound or opened while the mutex is held, giving us an opportunity
to do bus resets on release.  Note that no restriction is added to
serialize binding devices to vfio-pci while the mutex is held.

Signed-off-by: Alex Williamson <alex.williamson@redhat.com>
diff --git a/drivers/vfio/pci/vfio_pci.c b/drivers/vfio/pci/vfio_pci.c
index fc011e1..c9d756b 100644
--- a/drivers/vfio/pci/vfio_pci.c
+++ b/drivers/vfio/pci/vfio_pci.c
@@ -37,6 +37,8 @@
 MODULE_PARM_DESC(nointxmask,
 		  "Disable support for PCI 2.3 style INTx masking.  If this resolves problems for specific devices, report lspci -vvvxxx to linux-pci@vger.kernel.org so the device can be fixed automatically via the broken_intx_masking flag.");
 
+static DEFINE_MUTEX(driver_lock);
+
 static int vfio_pci_enable(struct vfio_pci_device *vdev)
 {
 	struct pci_dev *pdev = vdev->pdev;
@@ -163,23 +165,29 @@
 {
 	struct vfio_pci_device *vdev = device_data;
 
-	if (atomic_dec_and_test(&vdev->refcnt)) {
+	mutex_lock(&driver_lock);
+
+	if (!(--vdev->refcnt)) {
 		vfio_spapr_pci_eeh_release(vdev->pdev);
 		vfio_pci_disable(vdev);
 	}
 
+	mutex_unlock(&driver_lock);
+
 	module_put(THIS_MODULE);
 }
 
 static int vfio_pci_open(void *device_data)
 {
 	struct vfio_pci_device *vdev = device_data;
-	int ret;
+	int ret = 0;
 
 	if (!try_module_get(THIS_MODULE))
 		return -ENODEV;
 
-	if (atomic_inc_return(&vdev->refcnt) == 1) {
+	mutex_lock(&driver_lock);
+
+	if (!vdev->refcnt) {
 		ret = vfio_pci_enable(vdev);
 		if (ret)
 			goto error;
@@ -190,10 +198,11 @@
 			goto error;
 		}
 	}
-
-	return 0;
+	vdev->refcnt++;
 error:
-	module_put(THIS_MODULE);
+	mutex_unlock(&driver_lock);
+	if (ret)
+		module_put(THIS_MODULE);
 	return ret;
 }
 
@@ -849,7 +858,6 @@
 	vdev->irq_type = VFIO_PCI_NUM_IRQS;
 	mutex_init(&vdev->igate);
 	spin_lock_init(&vdev->irqlock);
-	atomic_set(&vdev->refcnt, 0);
 
 	ret = vfio_add_group_dev(&pdev->dev, &vfio_pci_ops, vdev);
 	if (ret) {
@@ -864,12 +872,15 @@
 {
 	struct vfio_pci_device *vdev;
 
-	vdev = vfio_del_group_dev(&pdev->dev);
-	if (!vdev)
-		return;
+	mutex_lock(&driver_lock);
 
-	iommu_group_put(pdev->dev.iommu_group);
-	kfree(vdev);
+	vdev = vfio_del_group_dev(&pdev->dev);
+	if (vdev) {
+		iommu_group_put(pdev->dev.iommu_group);
+		kfree(vdev);
+	}
+
+	mutex_unlock(&driver_lock);
 }
 
 static pci_ers_result_t vfio_pci_aer_err_detected(struct pci_dev *pdev,