netlabel: Add an address family to domain hash entries.

The reason is to allow different labelling protocols for
different address families with the same domain.

This requires the addition of an address family attribute
in the netlink communication protocol.  It is used in several
messages:

NLBL_MGMT_C_ADD and NLBL_MGMT_C_ADDDEF take it as an optional
attribute for the unlabelled protocol.  It may be one of AF_INET,
AF_INET6 or AF_UNSPEC (to specify both address families).  If it
is missing, it defaults to AF_UNSPEC.

NLBL_MGMT_C_LISTALL and NLBL_MGMT_C_LISTDEF return it as part of
the enumeration of each item.  Additionally, it may be sent to
LISTDEF to specify which address family to return.

Signed-off-by: Huw Davies <huw@codeweavers.com>
Signed-off-by: Paul Moore <paul@paul-moore.com>
diff --git a/net/netlabel/netlabel_domainhash.c b/net/netlabel/netlabel_domainhash.c
index d4d6640..3b3b304 100644
--- a/net/netlabel/netlabel_domainhash.c
+++ b/net/netlabel/netlabel_domainhash.c
@@ -56,7 +56,8 @@
 #define netlbl_domhsh_rcu_deref(p) \
 	rcu_dereference_check(p, lockdep_is_held(&netlbl_domhsh_lock))
 static struct netlbl_domhsh_tbl __rcu *netlbl_domhsh;
-static struct netlbl_dom_map __rcu *netlbl_domhsh_def;
+static struct netlbl_dom_map __rcu *netlbl_domhsh_def_ipv4; /* AF_INET */
+static struct netlbl_dom_map __rcu *netlbl_domhsh_def_ipv6; /* AF_INET6 */
 
 /*
  * Domain Hash Table Helper Functions
@@ -126,18 +127,36 @@
 	return val & (netlbl_domhsh_rcu_deref(netlbl_domhsh)->size - 1);
 }
 
+/**
+ * netlbl_family_match - Test whether two address families are compatible
+ * @f1: the first address family
+ * @f2: the second address family
+ *
+ * Description:
+ * Returns true if @f1 and @f2 are the same address family or if either of
+ * them is %AF_UNSPEC, which acts as a wildcard that matches any family.
+ *
+ */
+static bool netlbl_family_match(u16 f1, u16 f2)
+{
+	return (f1 == f2) || (f1 == AF_UNSPEC) || (f2 == AF_UNSPEC);
+}
+
 /**
  * netlbl_domhsh_search - Search for a domain entry
  * @domain: the domain
+ * @family: the address family
  *
  * Description:
  * Searches the domain hash table and returns a pointer to the hash table
- * entry if found, otherwise NULL is returned.  The caller is responsible for
+ * entry if found, otherwise NULL is returned.  @family may be %AF_UNSPEC
+ * which matches any address family entries.  The caller is responsible for
  * ensuring that the hash table is protected with either a RCU read lock or the
  * hash table lock.
  *
  */
-static struct netlbl_dom_map *netlbl_domhsh_search(const char *domain)
+static struct netlbl_dom_map *netlbl_domhsh_search(const char *domain,
+						   u16 family)
 {
 	u32 bkt;
 	struct list_head *bkt_list;
@@ -147,7 +156,9 @@
 		bkt = netlbl_domhsh_hash(domain);
 		bkt_list = &netlbl_domhsh_rcu_deref(netlbl_domhsh)->tbl[bkt];
 		list_for_each_entry_rcu(iter, bkt_list, list)
-			if (iter->valid && strcmp(iter->domain, domain) == 0)
+			if (iter->valid &&
+			    netlbl_family_match(iter->family, family) &&
+			    strcmp(iter->domain, domain) == 0)
 				return iter;
 	}
 
@@ -157,28 +168,37 @@
 /**
  * netlbl_domhsh_search_def - Search for a domain entry
  * @domain: the domain
- * @def: return default if no match is found
+ * @family: the address family; %AF_UNSPEC prefers the IPv4 default entry
  *
  * Description:
  * Searches the domain hash table and returns a pointer to the hash table
  * entry if an exact match is found, if an exact match is not present in the
  * hash table then the default entry is returned if valid otherwise NULL is
- * returned.  The caller is responsible ensuring that the hash table is
+ * returned.  @family may be %AF_UNSPEC which matches any address family
+ * entries.  The caller is responsible for ensuring that the hash table is
  * protected with either a RCU read lock or the hash table lock.
  *
  */
-static struct netlbl_dom_map *netlbl_domhsh_search_def(const char *domain)
+static struct netlbl_dom_map *netlbl_domhsh_search_def(const char *domain,
+						       u16 family)
 {
 	struct netlbl_dom_map *entry;
 
-	entry = netlbl_domhsh_search(domain);
-	if (entry == NULL) {
-		entry = netlbl_domhsh_rcu_deref(netlbl_domhsh_def);
-		if (entry != NULL && !entry->valid)
-			entry = NULL;
+	entry = netlbl_domhsh_search(domain, family);
+	if (entry != NULL)
+		return entry;
+	if (family == AF_INET || family == AF_UNSPEC) {
+		entry = netlbl_domhsh_rcu_deref(netlbl_domhsh_def_ipv4);
+		if (entry != NULL && entry->valid)
+			return entry;
+	}
+	if (family == AF_INET6 || family == AF_UNSPEC) {
+		entry = netlbl_domhsh_rcu_deref(netlbl_domhsh_def_ipv6);
+		if (entry != NULL && entry->valid)
+			return entry;
 	}
 
-	return entry;
+	return NULL;
 }
 
 /**
@@ -264,13 +284,19 @@
 	if (entry == NULL)
 		return -EINVAL;
 
+	if (entry->family != AF_INET && entry->family != AF_INET6 &&
+	    (entry->family != AF_UNSPEC ||
+	     entry->def.type != NETLBL_NLTYPE_UNLABELED))
+		return -EINVAL;
+
 	switch (entry->def.type) {
 	case NETLBL_NLTYPE_UNLABELED:
 		if (entry->def.cipso != NULL || entry->def.addrsel != NULL)
 			return -EINVAL;
 		break;
 	case NETLBL_NLTYPE_CIPSOV4:
-		if (entry->def.cipso == NULL)
+		if (entry->family != AF_INET ||
+		    entry->def.cipso == NULL)
 			return -EINVAL;
 		break;
 	case NETLBL_NLTYPE_ADDRSELECT:
@@ -358,15 +384,18 @@
  *
  * Description:
  * Adds a new entry to the domain hash table and handles any updates to the
- * lower level protocol handler (i.e. CIPSO).  Returns zero on success,
- * negative on failure.
+ * lower level protocol handler (i.e. CIPSO).  @entry->family may be set to
+ * %AF_UNSPEC which will add an entry that matches all address families.  This
+ * is only useful for the unlabelled type and will only succeed if there is no
+ * existing entry for any address family with the same domain.  Returns zero
+ * on success, negative on failure.
  *
  */
 int netlbl_domhsh_add(struct netlbl_dom_map *entry,
 		      struct netlbl_audit *audit_info)
 {
 	int ret_val = 0;
-	struct netlbl_dom_map *entry_old;
+	struct netlbl_dom_map *entry_old, *entry_b;
 	struct netlbl_af4list *iter4;
 	struct netlbl_af4list *tmp4;
 #if IS_ENABLED(CONFIG_IPV6)
@@ -385,9 +414,10 @@
 	rcu_read_lock();
 	spin_lock(&netlbl_domhsh_lock);
 	if (entry->domain != NULL)
-		entry_old = netlbl_domhsh_search(entry->domain);
+		entry_old = netlbl_domhsh_search(entry->domain, entry->family);
 	else
-		entry_old = netlbl_domhsh_search_def(entry->domain);
+		entry_old = netlbl_domhsh_search_def(entry->domain,
+						     entry->family);
 	if (entry_old == NULL) {
 		entry->valid = 1;
 
@@ -397,7 +427,41 @@
 				    &rcu_dereference(netlbl_domhsh)->tbl[bkt]);
 		} else {
 			INIT_LIST_HEAD(&entry->list);
-			rcu_assign_pointer(netlbl_domhsh_def, entry);
+			switch (entry->family) {
+			case AF_INET:
+				rcu_assign_pointer(netlbl_domhsh_def_ipv4,
+						   entry);
+				break;
+			case AF_INET6:
+				rcu_assign_pointer(netlbl_domhsh_def_ipv6,
+						   entry);
+				break;
+			case AF_UNSPEC:
+				if (entry->def.type !=
+				    NETLBL_NLTYPE_UNLABELED) {
+					ret_val = -EINVAL;
+					goto add_return;
+				}
+				entry_b = kzalloc(sizeof(*entry_b), GFP_ATOMIC);
+				if (entry_b == NULL) {
+					ret_val = -ENOMEM;
+					goto add_return;
+				}
+				entry_b->family = AF_INET6;
+				entry_b->def.type = NETLBL_NLTYPE_UNLABELED;
+				entry_b->valid = 1;
+				entry->family = AF_INET;
+				rcu_assign_pointer(netlbl_domhsh_def_ipv4,
+						   entry);
+				rcu_assign_pointer(netlbl_domhsh_def_ipv6,
+						   entry_b);
+				break;
+			default:
+				/* Already checked in
+				 * netlbl_domhsh_validate(). */
+				ret_val = -EINVAL;
+				goto add_return;
+			}
 		}
 
 		if (entry->def.type == NETLBL_NLTYPE_ADDRSELECT) {
@@ -513,10 +577,12 @@
 	spin_lock(&netlbl_domhsh_lock);
 	if (entry->valid) {
 		entry->valid = 0;
-		if (entry != rcu_dereference(netlbl_domhsh_def))
-			list_del_rcu(&entry->list);
+		if (entry == rcu_dereference(netlbl_domhsh_def_ipv4))
+			RCU_INIT_POINTER(netlbl_domhsh_def_ipv4, NULL);
+		else if (entry == rcu_dereference(netlbl_domhsh_def_ipv6))
+			RCU_INIT_POINTER(netlbl_domhsh_def_ipv6, NULL);
 		else
-			RCU_INIT_POINTER(netlbl_domhsh_def, NULL);
+			list_del_rcu(&entry->list);
 	} else
 		ret_val = -ENOENT;
 	spin_unlock(&netlbl_domhsh_lock);
@@ -583,9 +649,9 @@
 	rcu_read_lock();
 
 	if (domain)
-		entry_map = netlbl_domhsh_search(domain);
+		entry_map = netlbl_domhsh_search(domain, AF_INET);
 	else
-		entry_map = netlbl_domhsh_search_def(domain);
+		entry_map = netlbl_domhsh_search_def(domain, AF_INET);
 	if (entry_map == NULL ||
 	    entry_map->def.type != NETLBL_NLTYPE_ADDRSELECT)
 		goto remove_af4_failure;
@@ -625,25 +691,45 @@
 /**
  * netlbl_domhsh_remove - Removes an entry from the domain hash table
  * @domain: the domain to remove
+ * @family: address family
  * @audit_info: NetLabel audit information
  *
  * Description:
  * Removes an entry from the domain hash table and handles any updates to the
- * lower level protocol handler (i.e. CIPSO).  Returns zero on success,
- * negative on failure.
+ * lower level protocol handler (i.e. CIPSO).  @family may be %AF_UNSPEC which
+ * removes all address family entries.  Returns zero on success, -ENOENT if no
+ * matching entry was found, other negative values on failure.
  *
  */
-int netlbl_domhsh_remove(const char *domain, struct netlbl_audit *audit_info)
+int netlbl_domhsh_remove(const char *domain, u16 family,
+			 struct netlbl_audit *audit_info)
 {
-	int ret_val;
+	int ret_val = -EINVAL;
 	struct netlbl_dom_map *entry;
 
 	rcu_read_lock();
-	if (domain)
-		entry = netlbl_domhsh_search(domain);
-	else
-		entry = netlbl_domhsh_search_def(domain);
-	ret_val = netlbl_domhsh_remove_entry(entry, audit_info);
+
+	if (family == AF_INET || family == AF_UNSPEC) {
+		if (domain)
+			entry = netlbl_domhsh_search(domain, AF_INET);
+		else
+			entry = netlbl_domhsh_search_def(domain, AF_INET);
+		ret_val = netlbl_domhsh_remove_entry(entry, audit_info);
+		if (ret_val && ret_val != -ENOENT)
+			goto done;
+	}
+	if (family == AF_INET6 || family == AF_UNSPEC) {
+		int ret_val2;
+
+		if (domain)
+			entry = netlbl_domhsh_search(domain, AF_INET6);
+		else
+			entry = netlbl_domhsh_search_def(domain, AF_INET6);
+		ret_val2 = netlbl_domhsh_remove_entry(entry, audit_info);
+		if (family == AF_INET6 || ret_val2 != -ENOENT)
+			ret_val = ret_val2;
+	}
+done:
 	rcu_read_unlock();
 
 	return ret_val;
@@ -651,32 +737,38 @@
 
 /**
  * netlbl_domhsh_remove_default - Removes the default entry from the table
+ * @family: address family
  * @audit_info: NetLabel audit information
  *
  * Description:
- * Removes/resets the default entry for the domain hash table and handles any
- * updates to the lower level protocol handler (i.e. CIPSO).  Returns zero on
- * success, non-zero on failure.
+ * Removes/resets the default entry corresponding to @family from the domain
+ * hash table and handles any updates to the lower level protocol handler
+ * (i.e. CIPSO).  @family may be %AF_UNSPEC which removes all address family
+ * entries.  Returns zero on success, negative on failure.
  *
  */
-int netlbl_domhsh_remove_default(struct netlbl_audit *audit_info)
+int netlbl_domhsh_remove_default(u16 family, struct netlbl_audit *audit_info)
 {
-	return netlbl_domhsh_remove(NULL, audit_info);
+	return netlbl_domhsh_remove(NULL, family, audit_info);
 }
 
 /**
  * netlbl_domhsh_getentry - Get an entry from the domain hash table
  * @domain: the domain name to search for
+ * @family: address family
  *
  * Description:
  * Look through the domain hash table searching for an entry to match @domain,
- * return a pointer to a copy of the entry or NULL.  The caller is responsible
- * for ensuring that rcu_read_[un]lock() is called.
+ * with address family @family (falling back to the default entry), and
+ * return a pointer to it or NULL; %AF_UNSPEC always yields NULL.  The caller
+ * is responsible for ensuring that rcu_read_[un]lock() is called.
  *
  */
-struct netlbl_dom_map *netlbl_domhsh_getentry(const char *domain)
+struct netlbl_dom_map *netlbl_domhsh_getentry(const char *domain, u16 family)
 {
-	return netlbl_domhsh_search_def(domain);
+	if (family == AF_UNSPEC)
+		return NULL;
+	return netlbl_domhsh_search_def(domain, family);
 }
 
 /**
@@ -696,7 +788,7 @@
 	struct netlbl_dom_map *dom_iter;
 	struct netlbl_af4list *addr_iter;
 
-	dom_iter = netlbl_domhsh_search_def(domain);
+	dom_iter = netlbl_domhsh_search_def(domain, AF_INET);
 	if (dom_iter == NULL)
 		return NULL;
 
@@ -726,7 +818,7 @@
 	struct netlbl_dom_map *dom_iter;
 	struct netlbl_af6list *addr_iter;
 
-	dom_iter = netlbl_domhsh_search_def(domain);
+	dom_iter = netlbl_domhsh_search_def(domain, AF_INET6);
 	if (dom_iter == NULL)
 		return NULL;