iommu/amd: Add invalid_ppr callback

This callback lets a device driver change the PRI response code that is
sent to the device when a PPR fault could not be handled. Without a
registered callback the response remains PPR_INVALID.

Signed-off-by: Joerg Roedel <joerg.roedel@amd.com>
diff --git a/drivers/iommu/amd_iommu_v2.c b/drivers/iommu/amd_iommu_v2.c
index abdb839..fe812e2 100644
--- a/drivers/iommu/amd_iommu_v2.c
+++ b/drivers/iommu/amd_iommu_v2.c
@@ -62,6 +62,7 @@
 	struct iommu_domain *domain;
 	int pasid_levels;
 	int max_pasids;
+	amd_iommu_invalid_ppr_cb inv_ppr_cb;
 	spinlock_t lock;
 	wait_queue_head_t wq;
 };
@@ -505,10 +506,31 @@
 	npages = get_user_pages(fault->state->task, fault->state->mm,
 				fault->address, 1, write, 0, &page, NULL);
 
-	if (npages == 1)
+	if (npages == 1) {
 		put_page(page);
-	else
+	} else if (fault->dev_state->inv_ppr_cb) {
+		int status;
+
+		status = fault->dev_state->inv_ppr_cb(fault->dev_state->pdev,
+						      fault->pasid,
+						      fault->address,
+						      fault->flags);
+		switch (status) {
+		case AMD_IOMMU_INV_PRI_RSP_SUCCESS:
+			set_pri_tag_status(fault->state, fault->tag, PPR_SUCCESS);
+			break;
+		case AMD_IOMMU_INV_PRI_RSP_INVALID:
+			set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
+			break;
+		case AMD_IOMMU_INV_PRI_RSP_FAIL:
+			set_pri_tag_status(fault->state, fault->tag, PPR_FAILURE);
+			break;
+		default:
+			BUG();
+		}
+	} else {
 		set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
+	}
 
 	finish_pri_tag(fault->dev_state, fault->state, fault->tag);
 
@@ -828,6 +850,37 @@
 }
 EXPORT_SYMBOL(amd_iommu_free_device);
 
+/*
+ * Store @cb in the device state of @pdev as its inv_ppr_cb callback.
+ * Returns 0 on success, -ENODEV if amd_iommu_v2_supported() is false
+ * and -EINVAL when state_table holds no entry for the device.
+ */
+int amd_iommu_set_invalid_ppr_cb(struct pci_dev *pdev,
+				 amd_iommu_invalid_ppr_cb cb)
+{
+	struct device_state *state;
+	unsigned long irq_flags;
+	int err = -EINVAL;
+	u16 id;
+
+	if (!amd_iommu_v2_supported())
+		return -ENODEV;
+
+	id = device_id(pdev);
+
+	/* Look up and update the table entry with state_lock held. */
+	spin_lock_irqsave(&state_lock, irq_flags);
+	state = state_table[id];
+	if (state != NULL) {
+		state->inv_ppr_cb = cb;
+		err = 0;
+	}
+	spin_unlock_irqrestore(&state_lock, irq_flags);
+
+	return err;
+}
+EXPORT_SYMBOL(amd_iommu_set_invalid_ppr_cb);
+
 static int __init amd_iommu_v2_init(void)
 {
 	size_t state_table_size;