iommu: Limit iommu_attach/detach_device to devices with their own group

This patch changes the behavior of the iommu_attach_device
and iommu_detach_device functions. With this change these
functions only work on devices that have their own group.
For all other devices the iommu_attach_group/iommu_detach_group
functions must be used.

Signed-off-by: Joerg Roedel <jroedel@suse.de>
diff --git a/drivers/iommu/iommu.c b/drivers/iommu/iommu.c
index 49eb9bf..ef73923 100644
--- a/drivers/iommu/iommu.c
+++ b/drivers/iommu/iommu.c
@@ -433,6 +433,17 @@
 }
 EXPORT_SYMBOL_GPL(iommu_group_remove_device);
 
+static int iommu_group_device_count(struct iommu_group *group)
+{
+	struct iommu_device *entry;	/* list cursor; fields are not read */
+	int ret = 0;	/* entries seen on group->devices so far */
+
+	list_for_each_entry(entry, &group->devices, list)
+		ret++;
+
+	return ret;	/* no lock taken here; callers hold group->mutex */
+}
+
 /**
  * iommu_group_for_each_dev - iterate over each device in the group
  * @group: the group
@@ -969,7 +980,8 @@
 }
 EXPORT_SYMBOL_GPL(iommu_domain_free);
 
-int iommu_attach_device(struct iommu_domain *domain, struct device *dev)
+static int __iommu_attach_device(struct iommu_domain *domain,
+				 struct device *dev)
 {
 	int ret;
 	if (unlikely(domain->ops->attach_dev == NULL))
@@ -980,9 +992,38 @@
 		trace_attach_device_to_domain(dev);
 	return ret;
 }
+
+int iommu_attach_device(struct iommu_domain *domain, struct device *dev)
+{
+	struct iommu_group *group;
+	int ret;
+
+	group = iommu_group_get(dev);
+	/* FIXME: Remove this when groups are mandatory for iommu drivers */
+	if (group == NULL)
+		return __iommu_attach_device(domain, dev);
+
+	/*
+	 * We have a group - hold its mutex so the device count cannot change
+	 * while we attach; only a group with exactly one device is accepted
+	 */
+	mutex_lock(&group->mutex);
+	ret = -EINVAL;	/* result when the group does not hold one device */
+	if (iommu_group_device_count(group) != 1)
+		goto out_unlock;
+
+	ret = __iommu_attach_device(domain, dev);
+
+out_unlock:
+	mutex_unlock(&group->mutex);
+	iommu_group_put(group);	/* pairs with iommu_group_get() above */
+
+	return ret;
+}
 EXPORT_SYMBOL_GPL(iommu_attach_device);
 
-void iommu_detach_device(struct iommu_domain *domain, struct device *dev)
+static void __iommu_detach_device(struct iommu_domain *domain,
+				  struct device *dev)
 {
 	if (unlikely(domain->ops->detach_dev == NULL))
 		return;
@@ -990,6 +1031,28 @@
 	domain->ops->detach_dev(domain, dev);
 	trace_detach_device_from_domain(dev);
 }
+
+void iommu_detach_device(struct iommu_domain *domain, struct device *dev)
+{
+	struct iommu_group *group;
+
+	group = iommu_group_get(dev);
+	/* FIXME: Remove this when groups are mandatory for iommu drivers */
+	if (group == NULL)
+		return __iommu_detach_device(domain, dev);
+
+	mutex_lock(&group->mutex);
+	if (iommu_group_device_count(group) != 1) {
+		WARN_ON(1);	/* void function: warn, then skip the detach */
+		goto out_unlock;
+	}
+
+	__iommu_detach_device(domain, dev);
+
+out_unlock:
+	mutex_unlock(&group->mutex);
+	iommu_group_put(group);	/* pairs with iommu_group_get() above */
+}
 EXPORT_SYMBOL_GPL(iommu_detach_device);
 
 /*
@@ -1006,7 +1069,7 @@
 {
 	struct iommu_domain *domain = data;
 
-	return iommu_attach_device(domain, dev);
+	return __iommu_attach_device(domain, dev);
 }
 
 int iommu_attach_group(struct iommu_domain *domain, struct iommu_group *group)
@@ -1020,7 +1083,7 @@
 {
 	struct iommu_domain *domain = data;
 
-	iommu_detach_device(domain, dev);
+	__iommu_detach_device(domain, dev);
 
 	return 0;
 }