iommu: Make sure a device is always attached to a domain

Make use of the default domain and re-attach a device to it
when it is detached from another domain. Also enforce that a
device has to be in the default domain before it can be
attached to a different domain.

Signed-off-by: Joerg Roedel <jroedel@suse.de>
diff --git a/drivers/iommu/iommu.c b/drivers/iommu/iommu.c
index ef73923..7bce522 100644
--- a/drivers/iommu/iommu.c
+++ b/drivers/iommu/iommu.c
@@ -52,6 +52,7 @@
 	char *name;
 	int id;
 	struct iommu_domain *default_domain;
+	struct iommu_domain *domain;
 };
 
 struct iommu_device {
@@ -78,6 +79,12 @@
 
 static struct iommu_domain *__iommu_domain_alloc(struct bus_type *bus,
 						 unsigned type);
+static int __iommu_attach_device(struct iommu_domain *domain,
+				 struct device *dev);
+static int __iommu_attach_group(struct iommu_domain *domain,
+				struct iommu_group *group);
+static void __iommu_detach_group(struct iommu_domain *domain,
+				 struct iommu_group *group);
 
 static ssize_t iommu_group_attr_show(struct kobject *kobj,
 				     struct attribute *__attr, char *buf)
@@ -376,6 +383,8 @@
 
 	mutex_lock(&group->mutex);
 	list_add_tail(&device->list, &group->devices);
+	if (group->domain)
+		__iommu_attach_device(group->domain, dev);
 	mutex_unlock(&group->mutex);
 
 	/* Notify any listeners about change to group. */
@@ -455,19 +464,30 @@
  * The group->mutex is held across callbacks, which will block calls to
  * iommu_group_add/remove_device.
  */
-int iommu_group_for_each_dev(struct iommu_group *group, void *data,
-			     int (*fn)(struct device *, void *))
+static int __iommu_group_for_each_dev(struct iommu_group *group, void *data,
+				      int (*fn)(struct device *, void *))
 {
 	struct iommu_device *device;
 	int ret = 0;
 
-	mutex_lock(&group->mutex);
 	list_for_each_entry(device, &group->devices, list) {
 		ret = fn(device->dev, data);
 		if (ret)
 			break;
 	}
+	return ret;	/* 0, or the first non-zero value returned by fn */
+}
+
+/* Exported variant: walks the group with group->mutex held. */
+int iommu_group_for_each_dev(struct iommu_group *group, void *data,
+			     int (*fn)(struct device *, void *))
+{
+	int ret;
+
+	mutex_lock(&group->mutex);
+	ret = __iommu_group_for_each_dev(group, data, fn);
 	mutex_unlock(&group->mutex);
+
 	return ret;
 }
 EXPORT_SYMBOL_GPL(iommu_group_for_each_dev);
@@ -727,6 +747,7 @@
 		 */
 		group->default_domain = __iommu_domain_alloc(pdev->dev.bus,
 							     IOMMU_DOMAIN_DMA);
+		group->domain = group->default_domain;
 	}
 
 	return group;
@@ -1012,7 +1033,7 @@
 	if (iommu_group_device_count(group) != 1)
 		goto out_unlock;
 
-	ret = __iommu_attach_device(domain, dev);
+	ret = __iommu_attach_group(domain, group);
 
 out_unlock:
 	mutex_unlock(&group->mutex);
@@ -1047,7 +1068,7 @@
 		goto out_unlock;
 	}
 
-	__iommu_detach_device(domain, dev);
+	__iommu_detach_group(domain, group);
 
 out_unlock:
 	mutex_unlock(&group->mutex);
@@ -1072,10 +1093,31 @@
 	return __iommu_attach_device(domain, dev);
 }
 
+static int __iommu_attach_group(struct iommu_domain *domain,
+				struct iommu_group *group)
+{
+	struct iommu_domain *def = group->default_domain;
+	int ret;
+
+	if (def && group->domain != def)
+		return -EBUSY;	/* may only leave from the default domain */
+
+	ret = __iommu_group_for_each_dev(group, domain,
+					 iommu_group_do_attach_device);
+	if (!ret)
+		group->domain = domain;	/* every device callback returned 0 */
+	return ret;
+}
+
 int iommu_attach_group(struct iommu_domain *domain, struct iommu_group *group)
 {
-	return iommu_group_for_each_dev(group, domain,
-					iommu_group_do_attach_device);
+	int ret;
+
+	mutex_lock(&group->mutex);	/* held across the device walk */
+	ret = __iommu_attach_group(domain, group);
+	mutex_unlock(&group->mutex);
+
+	return ret;
 }
 EXPORT_SYMBOL_GPL(iommu_attach_group);
 
@@ -1088,9 +1130,35 @@
 	return 0;
 }
 
+static void __iommu_detach_group(struct iommu_domain *domain,
+				 struct iommu_group *group)
+{
+	struct iommu_domain *def = group->default_domain;
+	int ret;
+
+	if (!def) {	/* no default domain to fall back to: really detach */
+		__iommu_group_for_each_dev(group, domain,
+					   iommu_group_do_detach_device);
+		group->domain = NULL;
+		return;
+	}
+
+	if (group->domain == def)
+		return;	/* nothing attached on top of the default domain */
+
+	/* Otherwise "detach" by moving every device back to the default */
+	ret = __iommu_group_for_each_dev(group, def,
+					 iommu_group_do_attach_device);
+	if (WARN_ON(ret))
+		return;
+	group->domain = def;
+}
+
 void iommu_detach_group(struct iommu_domain *domain, struct iommu_group *group)
 {
-	iommu_group_for_each_dev(group, domain, iommu_group_do_detach_device);
+	mutex_lock(&group->mutex);	/* held across the device walk */
+	__iommu_detach_group(domain, group);
+	mutex_unlock(&group->mutex);
 }
 EXPORT_SYMBOL_GPL(iommu_detach_group);