iommu/vt-d: Fix mm refcounting to hold mm_count not mm_users

Holding mm_users works OK for graphics, which was the first user of SVM
with VT-d. However, it works less well for other devices, where we actually
do a mmap() from the file descriptor to which the SVM PASID state is tied.

In this case, on process exit we end up with a circular reference:
 - The MM remains alive until the file is closed and the driver's release()
   call ends up unbinding the PASID.
 - The VMA corresponding to the mmap() remains intact until the MM is
   destroyed.
 - Thus the file isn't closed, even when exit_files() runs, because the
   VMA is still holding a reference to it. And the MM remains alive…

To address this issue, we *stop* holding mm_users while the PASID is bound.
We already hold mm_count by virtue of the MMU notifier, and that can be
made to be sufficient.

It means that for a period during process exit, the fun part of mmput()
has happened and exit_mmap() has been called so the MM is basically
defunct. But the PGD still exists and the PASID is still bound to it.

During this period, we have to be very careful: exit_mmap() doesn't use
mm->mmap_sem because it doesn't expect anyone else to be touching the MM
(quite reasonably, since mm_users is zero). So we also need to fix the
fault handler to just report failure if mm_users is already zero, and to
temporarily bump mm_users while handling any faults.
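
As a rough illustration of that fault-path rule, here is a minimal
userspace sketch in C11. It is not the driver code: struct fake_mm and
the fake_*() helpers are hypothetical stand-ins for mm_users,
atomic_inc_not_zero() and mmput(), and the final drop only decrements
the counter rather than tearing anything down.

```c
#include <assert.h>
#include <stdatomic.h>
#include <stdbool.h>

/* Hypothetical stand-in for the mm_users field of struct mm_struct. */
struct fake_mm {
	atomic_int users;
};

/* Take a reference only if the count has not already dropped to zero. */
static bool fake_mm_get_not_zero(struct fake_mm *mm)
{
	int old = atomic_load(&mm->users);

	while (old != 0) {
		/* On failure, 'old' is refreshed with the current value. */
		if (atomic_compare_exchange_weak(&mm->users, &old, old + 1))
			return true;
	}
	return false;
}

/* Drop a reference taken by fake_mm_get_not_zero(). */
static void fake_mm_put(struct fake_mm *mm)
{
	int old = atomic_fetch_sub(&mm->users, 1);

	assert(old > 0);
}

/* Returns true if the fault was serviced, false if it was rejected. */
static bool fake_handle_fault(struct fake_mm *mm)
{
	/* A defunct address space (users == 0) must not be touched. */
	if (!fake_mm_get_not_zero(mm))
		return false;

	/* ... the VMA lookup and fault handling would go here ... */

	fake_mm_put(mm);
	return true;
}
```

The point of the pattern is that the zero check and the increment are a
single atomic step, so a fault can never revive an MM whose user count
has already reached zero.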

Additionally, exit_mmap() calls mmu_notifier_release() *before* it tears
down the page tables, which is too early for us to flush the IOTLB for
this PASID. And __mmu_notifier_release() removes every notifier from the
list, so when exit_mmap() finally *does* tear down the mappings and
clear the page tables, we don't get notified. So we work around this by
clearing the PASID table entry in our MMU notifier release() callback.
That way, the hardware *can't* get any pages back from the page tables
before they get cleared.

Hardware designers have confirmed that the resulting 'PASID not present'
faults should be handled just as gracefully as 'page not present' faults,
the important criterion being that they don't perturb the operation for
any *other* PASID in the system.

Signed-off-by: David Woodhouse <David.Woodhouse@intel.com>
Cc: stable@vger.kernel.org
diff --git a/drivers/iommu/intel-svm.c b/drivers/iommu/intel-svm.c
index 5046483..97a8189 100644
--- a/drivers/iommu/intel-svm.c
+++ b/drivers/iommu/intel-svm.c
@@ -249,12 +249,30 @@
 static void intel_mm_release(struct mmu_notifier *mn, struct mm_struct *mm)
 {
 	struct intel_svm *svm = container_of(mn, struct intel_svm, notifier);
+	struct intel_svm_dev *sdev;
 
+	/* This might end up being called from exit_mmap(), *before* the page
+	 * tables are cleared. And __mmu_notifier_release() will delete us from
+	 * the list of notifiers so that our invalidate_range() callback doesn't
+	 * get called when the page tables are cleared. So we need to protect
+	 * against hardware accessing those page tables.
+	 *
+	 * We do it by clearing the entry in the PASID table and then flushing
+	 * the IOTLB and the PASID table caches. This might upset hardware;
+	 * perhaps we'll want to point the PASID to a dummy PGD (like the zero
+	 * page) so that we end up taking a fault that the hardware really
+	 * *has* to handle gracefully without affecting other processes.
+	 */
 	svm->iommu->pasid_table[svm->pasid].val = 0;
+	wmb();
 
-	/* There's no need to do any flush because we can't get here if there
-	 * are any devices left anyway. */
-	WARN_ON(!list_empty(&svm->devs));
+	rcu_read_lock();
+	list_for_each_entry_rcu(sdev, &svm->devs, list) {
+		intel_flush_pasid_dev(svm, sdev, svm->pasid);
+		intel_flush_svm_range_dev(svm, sdev, 0, -1, 0, !svm->mm);
+	}
+	rcu_read_unlock();
+
 }
 
 static const struct mmu_notifier_ops intel_mmuops = {
@@ -379,7 +397,6 @@
 				goto out;
 			}
 			iommu->pasid_table[svm->pasid].val = (u64)__pa(mm->pgd) | 1;
-			mm = NULL;
 		} else
 			iommu->pasid_table[svm->pasid].val = (u64)__pa(init_mm.pgd) | 1 | (1ULL << 11);
 		wmb();
@@ -442,11 +459,11 @@
 				kfree_rcu(sdev, rcu);
 
 				if (list_empty(&svm->devs)) {
-					mmu_notifier_unregister(&svm->notifier, svm->mm);
 
 					idr_remove(&svm->iommu->pasid_idr, svm->pasid);
 					if (svm->mm)
-						mmput(svm->mm);
+						mmu_notifier_unregister(&svm->notifier, svm->mm);
+
 					/* We mandate that no page faults may be outstanding
 					 * for the PASID when intel_svm_unbind_mm() is called.
 					 * If that is not obeyed, subtle errors will happen.
@@ -551,6 +568,9 @@
 		 * any faults on kernel addresses. */
 		if (!svm->mm)
 			goto bad_req;
+		/* If the mm is already defunct, don't handle faults. */
+		if (!atomic_inc_not_zero(&svm->mm->mm_users))
+			goto bad_req;
 		down_read(&svm->mm->mmap_sem);
 		vma = find_extend_vma(svm->mm, address);
 		if (!vma || address < vma->vm_start)
@@ -567,6 +587,7 @@
 		result = QI_RESP_SUCCESS;
 	invalid:
 		up_read(&svm->mm->mmap_sem);
+		mmput(svm->mm);
 	bad_req:
 		/* Accounting for major/minor faults? */
 		rcu_read_lock();