[sheepdog] [RFC PATCH 1/3] make rbtree more typesafe

MORITA Kazutaka morita.kazutaka at gmail.com
Wed Sep 25 20:12:09 CEST 2013


From: MORITA Kazutaka <morita.kazutaka at lab.ntt.co.jp>

This patch adds stricter type checking to the rbtree functions and
forbids inserting invalid rb nodes into the rb root.

Signed-off-by: MORITA Kazutaka <morita.kazutaka at lab.ntt.co.jp>
---
 dog/dog.c                     |   6 +-
 dog/dog.h                     |   4 +-
 dog/farm/farm.c               |  21 +++----
 dog/farm/object_tree.c        |  32 ++++------
 dog/node.c                    |  14 ++---
 dog/trace.c                   |   7 ++-
 dog/vdi.c                     |  10 +--
 include/compiler.h            |   5 ++
 include/internal_proto.h      |   2 +
 include/rbtree.h              | 142 +++++++++++++++++++++++++++---------------
 include/sheep.h               |  48 +++++++-------
 include/sockfd_cache.h        |   2 +-
 lib/event.c                   |  10 +--
 lib/rbtree.c                  |  22 +++----
 lib/sockfd_cache.c            |  34 +++++-----
 sheep/cluster.h               |   8 +--
 sheep/cluster/corosync.c      |   6 +-
 sheep/cluster/local.c         |   8 +--
 sheep/cluster/zookeeper.c     |  28 +++++----
 sheep/group.c                 |  75 +++++++++++-----------
 sheep/md.c                    |  52 ++++++++--------
 sheep/object_cache.c          |  15 ++---
 sheep/object_list_cache.c     |  37 +++++------
 sheep/ops.c                   |   7 +--
 sheep/recovery.c              |   2 +-
 sheep/sheep_priv.h            |   2 +-
 sheep/vdi.c                   |  21 ++++---
 sheepfs/volume.c              |  17 ++---
 tests/unit/dog/mock_dog.c     |   4 +-
 tests/unit/mock/mock.c        |   6 +-
 tests/unit/mock/mock.h        |   6 +-
 tests/unit/sheep/mock_group.c |   6 +-
 tests/unit/sheep/test_hash.c  |  41 ++++++------
 33 files changed, 377 insertions(+), 323 deletions(-)

diff --git a/dog/dog.c b/dog/dog.c
index 16298ad..832cea8 100644
--- a/dog/dog.c
+++ b/dog/dog.c
@@ -49,8 +49,8 @@ static void usage(const struct command *commands, int status);
 uint32_t sd_epoch;
 
 int sd_nodes_nr;
-struct rb_root sd_vroot = RB_ROOT;
-struct rb_root sd_nroot = RB_ROOT;
+struct rb_vnode_root sd_vroot = RB_ROOT_INITIALIZER(vnode_cmp);
+struct rb_node_root sd_nroot = RB_ROOT_INITIALIZER(node_cmp);
 
 int update_node_list(int max_nodes)
 {
@@ -95,7 +95,7 @@ int update_node_list(int max_nodes)
 		struct sd_node *n = xmalloc(sizeof(*n));
 
 		*n = buf[i];
-		rb_insert(&sd_nroot, n, rb, node_cmp);
+		rb_insert(n, &sd_nroot);
 	}
 
 	nodes_to_vnodes(&sd_nroot, &sd_vroot);
diff --git a/dog/dog.h b/dog/dog.h
index 8c54c10..9d59b96 100644
--- a/dog/dog.h
+++ b/dog/dog.h
@@ -55,8 +55,8 @@ extern bool raw_output;
 extern bool verbose;
 
 extern uint32_t sd_epoch;
-extern struct rb_root sd_vroot;
-extern struct rb_root sd_nroot;
+extern struct rb_vnode_root sd_vroot;
+extern struct rb_node_root sd_nroot;
 extern int sd_nodes_nr;
 
 bool is_current(const struct sd_inode *i);
diff --git a/dog/farm/farm.c b/dog/farm/farm.c
index 0204d1a..ff0e74e 100644
--- a/dog/farm/farm.c
+++ b/dog/farm/farm.c
@@ -30,7 +30,6 @@ struct vdi_entry {
 	uint8_t  nr_copies;
 	struct rb_node rb;
 };
-static struct rb_root last_vdi_tree = RB_ROOT;
 
 struct snapshot_work {
 	struct trunk_entry entry;
@@ -45,13 +44,16 @@ static int vdi_cmp(const struct vdi_entry *e1, const struct vdi_entry *e2)
 	return strcmp(e1->name, e2->name);
 }
 
+static RB_ROOT(, struct vdi_entry, rb) last_vdi_tree =
+	RB_ROOT_INITIALIZER(vdi_cmp);
+
 static struct vdi_entry *find_vdi(const char *name)
 {
 	struct vdi_entry key = {};
 
 	pstrcpy(key.name, sizeof(key.name), name);
 
-	return rb_search(&last_vdi_tree, &key, rb, vdi_cmp);
+	return rb_search(&key, &last_vdi_tree);
 }
 
 static struct vdi_entry *new_vdi(const char *name, uint64_t vdi_size,
@@ -78,7 +80,7 @@ static void insert_vdi(struct sd_inode *new)
 			      new->vdi_id,
 			      new->snap_id,
 			      new->nr_copies);
-		rb_insert(&last_vdi_tree, vdi, rb, vdi_cmp);
+		rb_insert(vdi, &last_vdi_tree);
 	} else if (vdi->snap_id < new->snap_id) {
 		vdi->vdi_size = new->vdi_size;
 		vdi->vdi_id = new->vdi_id;
@@ -91,7 +93,7 @@ static int create_active_vdis(void)
 {
 	struct vdi_entry *vdi;
 	uint32_t new_vid;
-	rb_for_each_entry(vdi, &last_vdi_tree, rb) {
+	rb_for_each_entry(vdi, &last_vdi_tree) {
 		if (do_vdi_create(vdi->name,
 				  vdi->vdi_size,
 				  vdi->vdi_id, &new_vid,
@@ -101,15 +103,6 @@ static int create_active_vdis(void)
 	return 0;
 }
 
-static void free_vdi_list(void)
-{
-	struct vdi_entry *vdi;
-	rb_for_each_entry(vdi, &last_vdi_tree, rb) {
-		rb_erase(&vdi->rb, &last_vdi_tree);
-		free(vdi);
-	}
-}
-
 char *get_object_directory(void)
 {
 	return farm_object_dir;
@@ -414,6 +407,6 @@ int farm_load_snapshot(uint32_t idx, const char *tag)
 
 	ret = 0;
 out:
-	free_vdi_list();
+	rb_destroy(&last_vdi_tree);
 	return ret;
 }
diff --git a/dog/farm/object_tree.c b/dog/farm/object_tree.c
index c624fea..0485eaf 100644
--- a/dog/farm/object_tree.c
+++ b/dog/farm/object_tree.c
@@ -20,15 +20,6 @@ struct object_tree_entry {
 	struct rb_node node;
 };
 
-struct object_tree {
-	int nr_objs;
-	struct rb_root root;
-};
-
-static struct object_tree tree = {
-	.nr_objs = 0,
-	.root = RB_ROOT,
-};
 static struct object_tree_entry *cached_entry;
 
 static int object_tree_cmp(const struct object_tree_entry *a,
@@ -37,15 +28,18 @@ static int object_tree_cmp(const struct object_tree_entry *a,
 	return intcmp(a->oid, b->oid);
 }
 
-static struct object_tree_entry *do_insert(struct rb_root *root,
-				      struct object_tree_entry *new)
-{
-	return rb_insert(root, new, node, object_tree_cmp);
-}
+struct object_tree {
+	int nr_objs;
+	RB_ROOT(, struct object_tree_entry, node) root;
+};
+
+static struct object_tree tree = {
+	.nr_objs = 0,
+	.root = RB_ROOT_INITIALIZER(object_tree_cmp),
+};
 
 void object_tree_insert(uint64_t oid, int nr_copies)
 {
-	struct rb_root *root = &tree.root;
 	struct object_tree_entry *p = NULL;
 
 	if (!cached_entry)
@@ -53,7 +47,7 @@ void object_tree_insert(uint64_t oid, int nr_copies)
 	cached_entry->oid = oid;
 	cached_entry->nr_copies = nr_copies;
 	rb_init_node(&cached_entry->node);
-	p = do_insert(root, cached_entry);
+	p = rb_insert(cached_entry, &tree.root);
 	if (!p) {
 		tree.nr_objs++;
 		cached_entry = NULL;
@@ -65,13 +59,13 @@ void object_tree_print(void)
 	struct object_tree_entry *entry;
 	printf("nr_objs: %d\n", tree.nr_objs);
 
-	rb_for_each_entry(entry, &tree.root, node)
+	rb_for_each_entry(entry, &tree.root)
 		printf("Obj id: %"PRIx64"\n", entry->oid);
 }
 
 void object_tree_free(void)
 {
-	rb_destroy(&tree.root, struct object_tree_entry, node);
+	rb_destroy(&tree.root);
 	free(cached_entry);
 }
 
@@ -86,7 +80,7 @@ int for_each_object_in_tree(int (*func)(uint64_t oid, int nr_copies,
 	struct object_tree_entry *entry;
 	int ret = -1;
 
-	rb_for_each_entry(entry, &tree.root, node) {
+	rb_for_each_entry(entry, &tree.root) {
 		if (func(entry->oid, entry->nr_copies, data) < 0)
 			goto out;
 	}
diff --git a/dog/node.c b/dog/node.c
index 052739c..0d2d4a5 100644
--- a/dog/node.c
+++ b/dog/node.c
@@ -34,7 +34,7 @@ static int node_list(int argc, char **argv)
 
 	if (!raw_output)
 		printf("  Id   Host:Port         V-Nodes       Zone\n");
-	rb_for_each_entry(n, &sd_nroot, rb) {
+	rb_for_each_entry(n, &sd_nroot) {
 		const char *host = addr_to_str(n->nid.addr, n->nid.port);
 
 		printf(raw_output ? "%d %s %d %u\n" : "%4d   %-20s\t%2d%11u\n",
@@ -53,7 +53,7 @@ static int node_info(int argc, char **argv)
 	if (!raw_output)
 		printf("Id\tSize\tUsed\tAvail\tUse%%\n");
 
-	rb_for_each_entry(n, &sd_nroot, rb) {
+	rb_for_each_entry(n, &sd_nroot) {
 		struct sd_req req;
 		struct sd_rsp *rsp = (struct sd_rsp *)&req;
 
@@ -195,7 +195,7 @@ static int node_recovery(int argc, char **argv)
 		       "       Progress\n");
 	}
 
-	rb_for_each_entry(n, &sd_nroot, rb) {
+	rb_for_each_entry(n, &sd_nroot) {
 		struct sd_req req;
 		struct sd_rsp *rsp = (struct sd_rsp *)&req;
 		struct recovery_state state;
@@ -233,12 +233,12 @@ static int node_recovery(int argc, char **argv)
 	return EXIT_SUCCESS;
 }
 
-static struct sd_node *idx_to_node(struct rb_root *nroot, int idx)
+static struct sd_node *idx_to_node(struct rb_node_root *nroot, int idx)
 {
-	struct sd_node *n = rb_entry(rb_first(nroot), struct sd_node, rb);
+	struct sd_node *n = rb_first(nroot);
 
 	while (idx--)
-		n = rb_entry(rb_next(&n->rb), struct sd_node, rb);
+		n = rb_next(n, nroot);
 
 	return n;
 }
@@ -355,7 +355,7 @@ static int md_info(int argc, char **argv)
 	if (!node_cmd_data.all_nodes)
 		return node_md_info(&sd_nid);
 
-	rb_for_each_entry(n, &sd_nroot, rb) {
+	rb_for_each_entry(n, &sd_nroot) {
 		fprintf(stdout, "Node %d:\n", i++);
 		ret = node_md_info(&n->nid);
 		if (ret != EXIT_SUCCESS)
diff --git a/dog/trace.c b/dog/trace.c
index 806a3dd..9346849 100644
--- a/dog/trace.c
+++ b/dog/trace.c
@@ -252,8 +252,6 @@ struct graph_stat_entry {
 	uint16_t nr_calls;
 };
 
-static struct rb_root stat_tree_root;
-
 static LIST_HEAD(stat_list);
 
 static int graph_stat_cmp(const struct graph_stat_entry *a,
@@ -262,12 +260,15 @@ static int graph_stat_cmp(const struct graph_stat_entry *a,
 	return strcmp(a->fname, b->fname);
 }
 
+static RB_ROOT(, struct graph_stat_entry, rb) stat_tree_root =
+	RB_ROOT_INITIALIZER(graph_stat_cmp);
+
 static struct graph_stat_entry *
 stat_tree_insert(struct graph_stat_entry *new)
 {
 	struct graph_stat_entry *entry;
 
-	entry = rb_insert(&stat_tree_root, new, rb, graph_stat_cmp);
+	entry = rb_insert(new, &stat_tree_root);
 	if (entry) {
 		entry->duration += new->duration;
 		entry->nr_calls++;
diff --git a/dog/vdi.c b/dog/vdi.c
index a465e6a..402945c 100644
--- a/dog/vdi.c
+++ b/dog/vdi.c
@@ -320,7 +320,7 @@ static void parse_objs(uint64_t oid, obj_parser_func_t func, void *data, unsigne
 	char *buf;
 
 	buf = xzalloc(size);
-	rb_for_each_entry(n, &sd_nroot, rb) {
+	rb_for_each_entry(n, &sd_nroot) {
 		struct sd_req hdr;
 		struct sd_rsp *rsp = (struct sd_rsp *)&hdr;
 
@@ -922,8 +922,8 @@ static int do_track_object(uint64_t oid, uint8_t nr_copies)
 
 	nr_logs = rsp->data_length / sizeof(struct epoch_log);
 	for (i = nr_logs - 1; i >= 0; i--) {
-		struct rb_root vroot = RB_ROOT;
-		struct rb_root nroot = RB_ROOT;
+		struct rb_vnode_root vroot = RB_ROOT_INITIALIZER(vnode_cmp);
+		struct rb_node_root nroot = RB_ROOT_INITIALIZER(node_cmp);
 
 		printf("\nobj %"PRIx64" locations at epoch %d, copies = %d\n",
 		       oid, logs[i].epoch, nr_copies);
@@ -942,7 +942,7 @@ static int do_track_object(uint64_t oid, uint8_t nr_copies)
 			continue;
 		}
 		for (int k = 0; k < logs[i].nr_nodes; k++)
-			rb_insert(&nroot, &logs[i].nodes[k], rb, node_cmp);
+			rb_insert(&logs[i].nodes[k], &nroot);
 		nodes_to_vnodes(&nroot, &vroot);
 		oid_to_vnodes(oid, &vroot, nr_copies, vnode_buf);
 		for (j = 0; j < nr_copies; j++) {
@@ -950,7 +950,7 @@ static int do_track_object(uint64_t oid, uint8_t nr_copies)
 
 			printf("%s\n", addr_to_str(n->addr, n->port));
 		}
-		rb_destroy(&vroot, struct sd_vnode, rb);
+		rb_destroy(&vroot);
 	}
 
 	free(logs);
diff --git a/include/compiler.h b/include/compiler.h
index 324dacf..a071bd8 100644
--- a/include/compiler.h
+++ b/include/compiler.h
@@ -29,6 +29,11 @@
 	const typeof(((type *)0)->member) *__mptr = (ptr);	\
 	(type *)((char *)__mptr - offsetof(type, member)); })
 
+#define offset_container_of(ptr, type, offset)	\
+	(type *)((char *)(ptr) - (offset))
+#define offset_member_of(ptr, type, offset)	\
+	(type *)((char *)(ptr) + (offset))
+
 #define likely(x)       __builtin_expect(!!(x), 1)
 #define unlikely(x)     __builtin_expect(!!(x), 0)
 
diff --git a/include/internal_proto.h b/include/internal_proto.h
index 59c6e2a..c46b5a1 100644
--- a/include/internal_proto.h
+++ b/include/internal_proto.h
@@ -143,6 +143,8 @@ struct sd_node {
 	uint64_t        space;
 };
 
+RB_ROOT(rb_node_root, struct sd_node, rb);
+
 /*
  * A joining sheep multicasts the local cluster info.  Then, the existing nodes
  * reply the latest cluster info which is unique among all of the nodes.
diff --git a/include/rbtree.h b/include/rbtree.h
index 6aba6ad..a395cf5 100644
--- a/include/rbtree.h
+++ b/include/rbtree.h
@@ -12,7 +12,7 @@ struct rb_node {
 	struct rb_node *rb_left __attribute__ ((aligned (8)));
 };
 
-struct rb_root {
+struct __rb_root {
 	struct rb_node *rb_node;
 };
 
@@ -33,15 +33,38 @@ static inline void rb_set_color(struct rb_node *rb, int color)
 	rb->rb_parent_color = (rb->rb_parent_color & ~1) | color;
 }
 
-#define RB_ROOT { .rb_node = NULL }
-static inline void INIT_RB_ROOT(struct rb_root *root)
-{
-	root->rb_node = NULL;
+#define RB_ROOT_INITIALIZER(compar) { .cmp = compar }
+
+#define RB_ROOT(name, type, member)				\
+struct name {							\
+	struct __rb_root r;					\
+	int (*cmp)(const type *, const type *);			\
+								\
+	/* The below fields are used only at compile time */	\
+	type *t;						\
+	char o[offsetof(type, member)];				\
 }
 
-#define rb_entry(ptr, type, member) container_of(ptr, type, member)
+#define INIT_RB_ROOT(root, compar)		\
+({						\
+	(root)->r.rb_node = NULL;		\
+	(root)->cmp = compar;			\
+})
+
+#define rb_type_check(node, root) (void) ((node) == (root)->t)
+
+/* return NULL if ptr is NULL */
+#define rb_entry(ptr, root)						\
+	offset_container_of(ptr ?: rb_node_of(NULL, root),		\
+			    typeof(*(root)->t), sizeof((root)->o))
+
+#define rb_node_of(ptr, root)						\
+({									\
+	rb_type_check(ptr, root);					\
+	offset_member_of(ptr, struct rb_node, sizeof((root)->o));	\
+})
 
-#define RB_EMPTY_ROOT(root)     ((root)->rb_node == NULL)
+#define RB_EMPTY_ROOT(root)     ((root)->r.rb_node == NULL)
 #define RB_EMPTY_NODE(node)     (rb_parent(node) == node)
 #define RB_CLEAR_NODE(node)     (rb_set_parent(node, node))
 
@@ -53,21 +76,41 @@ static inline void rb_init_node(struct rb_node *rb)
 	RB_CLEAR_NODE(rb);
 }
 
-void rb_insert_color(struct rb_node *, struct rb_root *);
-void rb_erase(struct rb_node *, struct rb_root *);
+void __rb_insert_color(struct rb_node *, struct __rb_root *);
+void __rb_erase(struct rb_node *, struct __rb_root *);
+
+#define rb_erase(node, root) \
+	__rb_erase(rb_node_of(node, root), &(root)->r)
+
+#define rb_insert_color(node, root) \
+	__rb_insert_color(rb_node_of(node, root), &(root)->r)
 
 /* Find logical next and previous nodes in a tree */
-struct rb_node *rb_next(const struct rb_node *);
-struct rb_node *rb_prev(const struct rb_node *);
-struct rb_node *rb_first(const struct rb_root *);
-struct rb_node *rb_last(const struct rb_root *);
+struct rb_node *__rb_next(const struct rb_node *);
+struct rb_node *__rb_prev(const struct rb_node *);
+struct rb_node *__rb_first(const struct __rb_root *);
+struct rb_node *__rb_last(const struct __rb_root *);
+
+#define rb_next(node, root) \
+	rb_entry(__rb_next(rb_node_of(node, root)), root)
+
+#define rb_prev(node, root) \
+	rb_entry(__rb_prev(rb_node_of(node, root)), root)
+
+#define rb_first(root) rb_entry(__rb_first(&(root)->r), root)
+#define rb_last(root) rb_entry(__rb_last(&(root)->r), root)
 
 /* Fast replacement of a single node without remove/rebalance/add/rebalance */
-void rb_replace_node(struct rb_node *victim, struct rb_node *new,
-		struct rb_root *root);
+void __rb_replace_node(struct rb_node *victim, struct rb_node *new,
+		       struct __rb_root *root);
+
+#define rb_replace_node(victim, new, root)				\
+	__rb_replace_node(rb_node_of(victim, root),			\
+			  rb_node_of(new, root),			\
+			  &(root)->r)
 
 static inline void rb_link_node(struct rb_node *node, struct rb_node *parent,
-		struct rb_node **rb_link)
+				struct rb_node **rb_link)
 {
 	node->rb_parent_color = (unsigned long)parent;
 	node->rb_left = node->rb_right = NULL;
@@ -79,14 +122,14 @@ static inline void rb_link_node(struct rb_node *node, struct rb_node *parent,
  * Search for a value in the rbtree.  This returns NULL when the key is not
  * found in the rbtree.
  */
-#define rb_search(root, key, member, compar)				\
+#define rb_search(key, root)						\
 ({									\
-	struct rb_node *__n = (root)->rb_node;				\
+	struct rb_node *__n = (root)->r.rb_node;			\
 	typeof(key) __ret = NULL, __data;				\
 									\
 	while (__n) {							\
-		__data = rb_entry(__n, typeof(*key), member);		\
-		int __cmp = compar(key, __data);			\
+		__data = rb_entry(__n, root);				\
+		int __cmp = (root)->cmp(key, __data);			\
 									\
 		if (__cmp < 0)						\
 			__n = __n->rb_left;				\
@@ -104,14 +147,14 @@ static inline void rb_link_node(struct rb_node *node, struct rb_node *parent,
  * Insert a new node into the rbtree.  This returns NULL on success, or the
  * existing node on error.
  */
-#define rb_insert(root, new, member, compar)				\
+#define rb_insert(new, root)						\
 ({									\
-	struct rb_node **__n = &(root)->rb_node, *__parent = NULL;	\
+	struct rb_node **__n = &(root)->r.rb_node, *__parent = NULL;	\
 	typeof(new) __old = NULL, __data;				\
 									\
 	while (*__n) {							\
-		__data = rb_entry(*__n, typeof(*new), member);		\
-		int __cmp = compar(new, __data);			\
+		__data = rb_entry(*__n, root);				\
+		int __cmp = (root)->cmp(new, __data);			\
 									\
 		__parent = *__n;					\
 		if (__cmp < 0)						\
@@ -126,8 +169,8 @@ static inline void rb_link_node(struct rb_node *node, struct rb_node *parent,
 									\
 	if (__old == NULL) {						\
 		/* Add new node and rebalance tree. */			\
-		rb_link_node(&((new)->member), __parent, __n);		\
-		rb_insert_color(&((new)->member), root);		\
+		rb_link_node(rb_node_of(new, root), __parent, __n);	\
+		rb_insert_color(new, root);				\
 	}								\
 									\
 	__old;							\
@@ -140,14 +183,14 @@ static inline void rb_link_node(struct rb_node *node, struct rb_node *parent,
  *
  * For an empty tree, we return NULL.
  */
-#define rb_nsearch(root, key, member, compar)                           \
+#define rb_nsearch(key, root)						\
 ({                                                                      \
-        struct rb_node *__n = (root)->rb_node;                          \
-        typeof(key) __ret = NULL, __data;                               \
+        struct rb_node *__n = (root)->r.rb_node;			\
+        typeof(key) __ret = NULL, __data;				\
                                                                         \
-        while (__n) {                                                   \
-                __data = rb_entry(__n, typeof(*key), member);           \
-                int __cmp = compar(key, __data);                        \
+        while (__n) {							\
+                __data = rb_entry(__n, root);				\
+                int __cmp = (root)->cmp(key, __data);			\
                                                                         \
                 if (__cmp < 0) {                                        \
                         __ret = __data;                                 \
@@ -159,42 +202,41 @@ static inline void rb_link_node(struct rb_node *node, struct rb_node *parent,
                         break;                                          \
                 }                                                       \
         }                                                               \
-        if (!__ret && !RB_EMPTY_ROOT(root))                             \
-                __ret = rb_entry(rb_first(root), typeof(*key), member); \
-        __ret;                                                          \
+        if (!__ret && !RB_EMPTY_ROOT(root))				\
+                __ret = rb_first(root);					\
+        __ret;								\
 })
 
 /* Iterate over a rbtree safe against removal of rbnode */
 #define rb_for_each(pos, root)						\
-	for (struct rb_node *LOCAL(n) = (pos = rb_first(root), NULL);	\
-	     pos && (LOCAL(n) = rb_next(pos), 1);			\
+	for (struct rb_node *LOCAL(n) = (pos = __rb_first(root), NULL);	\
+	     pos && (LOCAL(n) = __rb_next(pos), 1);			\
 	     pos = LOCAL(n))
 
 /* Iterate over a rbtree of given type safe against removal of rbnode */
-#define rb_for_each_entry(pos, root, member)				\
-	for (struct rb_node *LOCAL(p) = rb_first(root), *LOCAL(n);	\
-	     LOCAL(p) && (LOCAL(n) = rb_next(LOCAL(p)), 1) &&		\
-		     (pos = rb_entry(LOCAL(p), typeof(*pos), member), 1); \
-	     LOCAL(p) = LOCAL(n))
+#define rb_for_each_entry(pos, root)					\
+	for (typeof(pos) LOCAL(n) = (pos = rb_first(root), NULL);	\
+	     pos && (LOCAL(n) = rb_next(pos, root), 1);			\
+	     pos = LOCAL(n))
 
 /* Destroy the tree and free the memory */
-#define rb_destroy(root, type, member)					\
+#define rb_destroy(root)						\
 ({									\
-	type *__dummy;							\
-	rb_for_each_entry(__dummy, root, member) {			\
-		rb_erase(&__dummy->member, root);			\
+	typeof((root)->t) __dummy;					\
+	rb_for_each_entry(__dummy, root) {				\
+		rb_erase(__dummy, root);				\
 		free(__dummy);						\
 	}								\
 })
 
 /* Copy the tree 'root' as 'outroot' */
-#define rb_copy(root, type, member, outroot, compar)			\
+#define rb_copy(root, outroot)						\
 ({									\
-	type *__src, *__dst;						\
-	rb_for_each_entry(__src, root, member) {			\
+	typeof(*(root)->t) *__src, *__dst;				\
+	rb_for_each_entry(__src, root) {				\
 		__dst = xmalloc(sizeof(*__dst));			\
 		*__dst = *__src;					\
-		rb_insert(outroot, __dst, member, compar);		\
+		rb_insert(__dst, outroot);				\
 	}								\
 })
 
diff --git a/include/sheep.h b/include/sheep.h
index 293e057..e4c9b01 100644
--- a/include/sheep.h
+++ b/include/sheep.h
@@ -20,15 +20,22 @@
 #include "net.h"
 #include "rbtree.h"
 
+struct sd_vnode;
+
+static inline int vnode_cmp(const struct sd_vnode *node1,
+			    const struct sd_vnode *node2);
+
 struct sd_vnode {
 	struct rb_node rb;
 	const struct sd_node *node;
 	uint64_t hash;
 };
 
+RB_ROOT(rb_vnode_root, struct sd_vnode, rb);
+
 struct vnode_info {
-	struct rb_root vroot;
-	struct rb_root nroot;
+	struct rb_vnode_root vroot;
+	struct rb_node_root nroot;
 	int nr_nodes;
 	int nr_zones;
 	refcnt_t refcnt;
@@ -70,27 +77,26 @@ static inline int vnode_cmp(const struct sd_vnode *node1,
 
 /* If v1_hash < oid_hash <= v2_hash, then oid is resident on v2 */
 static inline struct sd_vnode *
-oid_to_first_vnode(uint64_t oid, struct rb_root *root)
+oid_to_first_vnode(uint64_t oid, const struct rb_vnode_root *root)
 {
 	struct sd_vnode dummy = {
 		.hash = sd_hash_oid(oid),
 	};
-	return rb_nsearch(root, &dummy, rb, vnode_cmp);
+	return rb_nsearch(&dummy, root);
 }
 
 /* Replica are placed along the ring one by one with different zones */
-static inline void oid_to_vnodes(uint64_t oid, struct rb_root *root,
-				 int nr_copies,
-				 const struct sd_vnode **vnodes)
+static inline void oid_to_vnodes(uint64_t oid, const struct rb_vnode_root *root,
+				 int nr_copies, const struct sd_vnode **vnodes)
 {
 	const struct sd_vnode *next = oid_to_first_vnode(oid, root);
 
 	vnodes[0] = next;
 	for (int i = 1; i < nr_copies; i++) {
 next:
-		next = rb_entry(rb_next(&next->rb), struct sd_vnode, rb);
+		next = rb_next(next, root);
 		if (!next) /* Wrap around */
-			next = rb_entry(rb_first(root), struct sd_vnode, rb);
+			next = rb_first(root);
 		if (unlikely(next == vnodes[0]))
 			panic("can't find a valid vnode");
 		for (int j = 0; j < i; j++)
@@ -101,7 +107,7 @@ next:
 }
 
 static inline const struct sd_vnode *
-oid_to_vnode(uint64_t oid, struct rb_root *root, int copy_idx)
+oid_to_vnode(uint64_t oid, const struct rb_vnode_root *root, int copy_idx)
 {
 	const struct sd_vnode *vnodes[SD_MAX_COPIES];
 
@@ -111,7 +117,7 @@ oid_to_vnode(uint64_t oid, struct rb_root *root, int copy_idx)
 }
 
 static inline const struct sd_node *
-oid_to_node(uint64_t oid, struct rb_root *root, int copy_idx)
+oid_to_node(uint64_t oid, const struct rb_vnode_root *root, int copy_idx)
 {
 	const struct sd_vnode *vnode;
 
@@ -120,9 +126,8 @@ oid_to_node(uint64_t oid, struct rb_root *root, int copy_idx)
 	return vnode->node;
 }
 
-static inline void oid_to_nodes(uint64_t oid, struct rb_root *root,
-				int nr_copies,
-				const struct sd_node **nodes)
+static inline void oid_to_nodes(uint64_t oid, const struct rb_vnode_root *root,
+				int nr_copies, const struct sd_node **nodes)
 {
 	const struct sd_vnode *vnodes[SD_MAX_COPIES];
 
@@ -219,7 +224,7 @@ static inline bool node_eq(const struct sd_node *a, const struct sd_node *b)
 }
 
 static inline void
-node_to_vnodes(const struct sd_node *n, struct rb_root *vroot)
+node_to_vnodes(const struct sd_node *n, struct rb_vnode_root *vroot)
 {
 	uint64_t hval = sd_hash(&n->nid, offsetof(typeof(n->nid),
 						  io_addr));
@@ -230,25 +235,26 @@ node_to_vnodes(const struct sd_node *n, struct rb_root *vroot)
 		hval = sd_hash_next(hval);
 		v->hash = hval;
 		v->node = n;
-		if (unlikely(rb_insert(vroot, v, rb, vnode_cmp)))
+		if (unlikely(rb_insert(v, vroot)))
 			panic("vdisk hash collison");
 	}
 }
 
 static inline void
-nodes_to_vnodes(struct rb_root *nroot, struct rb_root *vroot)
+nodes_to_vnodes(const struct rb_node_root *nroot, struct rb_vnode_root *vroot)
 {
-	struct sd_node *n;
+	const struct sd_node *n;
 
-	rb_for_each_entry(n, nroot, rb)
+	rb_for_each_entry(n, nroot)
 		node_to_vnodes(n, vroot);
 }
 
-static inline void nodes_to_buffer(struct rb_root *nroot, void *buffer)
+static inline void nodes_to_buffer(const struct rb_node_root *nroot,
+				   void *buffer)
 {
 	struct sd_node *n, *buf = buffer;
 
-	rb_for_each_entry(n, nroot, rb) {
+	rb_for_each_entry(n, nroot) {
 		memcpy(buf++, n, sizeof(*n));
 	}
 }
diff --git a/include/sockfd_cache.h b/include/sockfd_cache.h
index 21cc2bf..64ed36e 100644
--- a/include/sockfd_cache.h
+++ b/include/sockfd_cache.h
@@ -9,7 +9,7 @@ void sockfd_cache_put(const struct node_id *nid, struct sockfd *sfd);
 void sockfd_cache_del_node(const struct node_id *nid);
 void sockfd_cache_del(const struct node_id *nid, struct sockfd *sfd);
 void sockfd_cache_add(const struct node_id *nid);
-void sockfd_cache_add_group(const struct rb_root *nroot);
+void sockfd_cache_add_group(const struct rb_node_root *nroot);
 
 int sockfd_init(void);
 
diff --git a/lib/event.c b/lib/event.c
index 88078f4..af04417 100644
--- a/lib/event.c
+++ b/lib/event.c
@@ -20,7 +20,6 @@
 #include "event.h"
 
 static int efd;
-static struct rb_root events_tree = RB_ROOT;
 
 static void timer_handler(int fd, int events, void *data)
 {
@@ -76,6 +75,9 @@ static int event_cmp(const struct event_info *e1, const struct event_info *e2)
 	return intcmp(e1->fd, e2->fd);
 }
 
+static RB_ROOT(, struct event_info, rb) events_tree =
+	RB_ROOT_INITIALIZER(event_cmp);
+
 int init_event(int nr)
 {
 	nr_events = nr;
@@ -93,7 +95,7 @@ static struct event_info *lookup_event(int fd)
 {
 	struct event_info key = { .fd = fd };
 
-	return rb_search(&events_tree, &key, rb, event_cmp);
+	return rb_search(&key, &events_tree);
 }
 
 int register_event_prio(int fd, event_handler_t h, void *data, int prio)
@@ -117,7 +119,7 @@ int register_event_prio(int fd, event_handler_t h, void *data, int prio)
 		sd_err("failed to add epoll event: %m");
 		free(ei);
 	} else
-		rb_insert(&events_tree, ei, rb, event_cmp);
+		rb_insert(ei, &events_tree);
 
 	return ret;
 }
@@ -135,7 +137,7 @@ void unregister_event(int fd)
 	if (ret)
 		sd_err("failed to delete epoll event for fd %d: %m", fd);
 
-	rb_erase(&ei->rb, &events_tree);
+	rb_erase(ei, &events_tree);
 	free(ei);
 
 	/*
diff --git a/lib/rbtree.c b/lib/rbtree.c
index 1e54659..86be567 100644
--- a/lib/rbtree.c
+++ b/lib/rbtree.c
@@ -22,7 +22,7 @@
 #include <unistd.h>
 #include "rbtree.h"
 
-static void __rb_rotate_left(struct rb_node *node, struct rb_root *root)
+static void __rb_rotate_left(struct rb_node *node, struct __rb_root *root)
 {
 	struct rb_node *right = node->rb_right;
 	struct rb_node *parent = rb_parent(node);
@@ -44,7 +44,7 @@ static void __rb_rotate_left(struct rb_node *node, struct rb_root *root)
 	rb_set_parent(node, right);
 }
 
-static void __rb_rotate_right(struct rb_node *node, struct rb_root *root)
+static void __rb_rotate_right(struct rb_node *node, struct __rb_root *root)
 {
 	struct rb_node *left = node->rb_left;
 	struct rb_node *parent = rb_parent(node);
@@ -66,7 +66,7 @@ static void __rb_rotate_right(struct rb_node *node, struct rb_root *root)
 	rb_set_parent(node, left);
 }
 
-void rb_insert_color(struct rb_node *node, struct rb_root *root)
+void __rb_insert_color(struct rb_node *node, struct __rb_root *root)
 {
 	struct rb_node *parent, *gparent;
 
@@ -122,7 +122,7 @@ void rb_insert_color(struct rb_node *node, struct rb_root *root)
 }
 
 static void __rb_erase_color(struct rb_node *node, struct rb_node *parent,
-			     struct rb_root *root)
+			     struct __rb_root *root)
 {
 	struct rb_node *other;
 
@@ -187,7 +187,7 @@ static void __rb_erase_color(struct rb_node *node, struct rb_node *parent,
 		rb_set_black(node);
 }
 
-void rb_erase(struct rb_node *node, struct rb_root *root)
+void __rb_erase(struct rb_node *node, struct __rb_root *root)
 {
 	struct rb_node *child, *parent;
 	int color;
@@ -252,7 +252,7 @@ void rb_erase(struct rb_node *node, struct rb_root *root)
 }
 
 /* This function returns the first node (in sort order) of the tree. */
-struct rb_node *rb_first(const struct rb_root *root)
+struct rb_node *__rb_first(const struct __rb_root *root)
 {
 	struct rb_node	*n;
 
@@ -264,7 +264,7 @@ struct rb_node *rb_first(const struct rb_root *root)
 	return n;
 }
 
-struct rb_node *rb_last(const struct rb_root *root)
+struct rb_node *__rb_last(const struct __rb_root *root)
 {
 	struct rb_node	*n;
 
@@ -276,7 +276,7 @@ struct rb_node *rb_last(const struct rb_root *root)
 	return n;
 }
 
-struct rb_node *rb_next(const struct rb_node *node)
+struct rb_node *__rb_next(const struct rb_node *node)
 {
 	struct rb_node *parent;
 
@@ -308,7 +308,7 @@ struct rb_node *rb_next(const struct rb_node *node)
 	return parent;
 }
 
-struct rb_node *rb_prev(const struct rb_node *node)
+struct rb_node *__rb_prev(const struct rb_node *node)
 {
 	struct rb_node *parent;
 
@@ -336,8 +336,8 @@ struct rb_node *rb_prev(const struct rb_node *node)
 	return parent;
 }
 
-void rb_replace_node(struct rb_node *victim, struct rb_node *new,
-		     struct rb_root *root)
+void __rb_replace_node(struct rb_node *victim, struct rb_node *new,
+		       struct __rb_root *root)
 {
 	struct rb_node *parent = rb_parent(victim);
 
diff --git a/lib/sockfd_cache.c b/lib/sockfd_cache.c
index 0bfb274..fc7e31f 100644
--- a/lib/sockfd_cache.c
+++ b/lib/sockfd_cache.c
@@ -35,17 +35,6 @@
 #include "util.h"
 #include "sheep.h"
 
-struct sockfd_cache {
-	struct rb_root root;
-	struct sd_lock lock;
-	int count;
-};
-
-static struct sockfd_cache sockfd_cache = {
-	.root = RB_ROOT,
-	.lock = SD_LOCK_INITIALIZER,
-};
-
 /*
  * Suppose request size from Guest is 512k, then 4M / 512k = 8, so at
  * most 8 requests can be issued to the same sheep object. Based on this
@@ -78,17 +67,28 @@ static int sockfd_cache_cmp(const struct sockfd_cache_entry *a,
 	return node_id_cmp(&a->nid, &b->nid);
 }
 
+struct sockfd_cache {
+	RB_ROOT(, struct sockfd_cache_entry, rb) root;
+	struct sd_lock lock;
+	int count;
+};
+
+static struct sockfd_cache sockfd_cache = {
+	.root = RB_ROOT_INITIALIZER(sockfd_cache_cmp),
+	.lock = SD_LOCK_INITIALIZER,
+};
+
 static struct sockfd_cache_entry *
 sockfd_cache_insert(struct sockfd_cache_entry *new)
 {
-	return rb_insert(&sockfd_cache.root, new, rb, sockfd_cache_cmp);
+	return rb_insert(new, &sockfd_cache.root);
 }
 
 static struct sockfd_cache_entry *sockfd_cache_search(const struct node_id *nid)
 {
 	struct sockfd_cache_entry key = { .nid = *nid };
 
-	return rb_search(&sockfd_cache.root, &key, rb, sockfd_cache_cmp);
+	return rb_search(&key, &sockfd_cache.root);
 }
 
 static inline int get_free_slot(struct sockfd_cache_entry *entry)
@@ -175,7 +175,7 @@ static bool sockfd_cache_destroy(const struct node_id *nid)
 		goto false_out;
 	}
 
-	rb_erase(&entry->rb, &sockfd_cache.root);
+	rb_erase(entry, &sockfd_cache.root);
 	sd_unlock(&sockfd_cache.lock);
 
 	destroy_all_slots(entry);
@@ -205,12 +205,12 @@ static void sockfd_cache_add_nolock(const struct node_id *nid)
 }
 
 /* Add group of nodes to the cache */
-void sockfd_cache_add_group(const struct rb_root *nroot)
+void sockfd_cache_add_group(const struct rb_node_root *nroot)
 {
 	struct sd_node *n;
 
 	sd_write_lock(&sockfd_cache.lock);
-	rb_for_each_entry(n, nroot, rb) {
+	rb_for_each_entry(n, nroot) {
 		sockfd_cache_add_nolock(&n->nid);
 	}
 	sd_unlock(&sockfd_cache.lock);
@@ -254,7 +254,7 @@ static void do_grow_fds(struct work *work)
 	old_fds_count = fds_count;
 	new_fds_count = fds_count * 2;
 	new_size = sizeof(struct sockfd_cache_fd) * fds_count * 2;
-	rb_for_each_entry(entry, &sockfd_cache.root, rb) {
+	rb_for_each_entry(entry, &sockfd_cache.root) {
 		entry->fds = xrealloc(entry->fds, new_size);
 		for (i = old_fds_count; i < new_fds_count; i++) {
 			entry->fds[i].fd = -1;
diff --git a/sheep/cluster.h b/sheep/cluster.h
index a267443..6a60033 100644
--- a/sheep/cluster.h
+++ b/sheep/cluster.h
@@ -163,16 +163,16 @@ static inline const char *get_cdrv_option(const struct cluster_driver *cdrv,
 
 /* callbacks back into sheepdog from the cluster drivers */
 void sd_accept_handler(const struct sd_node *joined,
-		       const struct rb_root *nroot, size_t nr_members,
+		       const struct rb_node_root *nroot, size_t nr_members,
 		       const void *opaque);
-void sd_leave_handler(const struct sd_node *left, const struct rb_root *nroot,
-		      size_t nr_members);
+void sd_leave_handler(const struct sd_node *left,
+		      const struct rb_node_root *nroot, size_t nr_members);
 void sd_notify_handler(const struct sd_node *sender, void *msg, size_t msg_len);
 bool sd_block_handler(const struct sd_node *sender);
 int sd_reconnect_handler(void);
 void sd_update_node_handler(struct sd_node *);
 bool sd_join_handler(const struct sd_node *joining,
-		     const struct rb_root *nroot, size_t nr_nodes,
+		     const struct rb_node_root *nroot, size_t nr_nodes,
 		     void *opaque);
 
 #endif
diff --git a/sheep/cluster/corosync.c b/sheep/cluster/corosync.c
index 45756e8..acc9e1c 100644
--- a/sheep/cluster/corosync.c
+++ b/sheep/cluster/corosync.c
@@ -239,10 +239,10 @@ find_event(enum corosync_event_type type, struct cpg_node *sender)
 }
 
 static void build_node_list(struct cpg_node *nodes, size_t nr_nodes,
-			    struct rb_root *nroot)
+			    struct rb_node_root *nroot)
 {
 	for (int i = 0; i < nr_nodes; i++)
-		rb_insert(nroot, &nodes[i].node, rb, node_cmp);
+		rb_insert(&nodes[i].node, nroot);
 }
 
 /*
@@ -254,7 +254,7 @@ static bool __corosync_dispatch_one(struct corosync_event *cevent)
 {
 	struct sd_node *node;
 	struct cpg_node *n;
-	struct rb_root nroot = RB_ROOT;
+	struct rb_node_root nroot = RB_ROOT_INITIALIZER(node_cmp);
 	int idx;
 
 	switch (cevent->type) {
diff --git a/sheep/cluster/local.c b/sheep/cluster/local.c
index 824279c..56880af 100644
--- a/sheep/cluster/local.c
+++ b/sheep/cluster/local.c
@@ -93,9 +93,9 @@ static struct shm_queue {
 	struct local_event nonblock_events[MAX_EVENTS];
 } *shm_queue;
 
-static inline void node_insert(struct sd_node *new, struct rb_root *root)
+static inline void node_insert(struct sd_node *new, struct rb_node_root *root)
 {
-	if (rb_insert(root, new, rb, node_cmp))
+	if (rb_insert(new, root))
 		panic("insert duplicate %s", node_to_str(new));
 }
 
@@ -388,7 +388,7 @@ static bool local_process_event(void)
 {
 	struct local_event *ev;
 	int i;
-	struct rb_root root = RB_ROOT;
+	struct rb_node_root root = RB_ROOT_INITIALIZER(node_cmp);
 	size_t nr_nodes = 0;
 
 	ev = shm_queue_peek();
@@ -432,7 +432,7 @@ static bool local_process_event(void)
 	case EVENT_JOIN:
 		for (i = 0; i < ev->nr_lnodes; i++)
 			if (node_eq(&ev->sender.node, &ev->lnodes[i].node)) {
-				rb_erase(&ev->lnodes[i].node.rb, &root);
+				rb_erase(&ev->lnodes[i].node, &root);
 				nr_nodes--;
 			}
 		if (sd_join_handler(&ev->sender.node, &root, nr_nodes,
diff --git a/sheep/cluster/zookeeper.c b/sheep/cluster/zookeeper.c
index 7ce8180..b3966bc 100644
--- a/sheep/cluster/zookeeper.c
+++ b/sheep/cluster/zookeeper.c
@@ -69,9 +69,8 @@ struct zk_event {
 	uint8_t buf[ZK_MAX_BUF_SIZE];
 };
 
-static struct rb_root sd_node_root = RB_ROOT;
+static struct rb_node_root sd_node_root = RB_ROOT_INITIALIZER(node_cmp);
 static size_t nr_sd_nodes;
-static struct rb_root zk_node_root = RB_ROOT;
 static struct sd_lock zk_tree_lock = SD_LOCK_INITIALIZER;
 static struct sd_lock zk_compete_master_lock = SD_LOCK_INITIALIZER;
 static LIST_HEAD(zk_block_list);
@@ -87,16 +86,19 @@ static int zk_node_cmp(const struct zk_node *a, const struct zk_node *b)
 	return node_id_cmp(&a->node.nid, &b->node.nid);
 }
 
+static RB_ROOT(, struct zk_node, rb) zk_node_root =
+	RB_ROOT_INITIALIZER(zk_node_cmp);
+
 static struct zk_node *zk_tree_insert(struct zk_node *new)
 {
-	return rb_insert(&zk_node_root, new, rb, zk_node_cmp);
+	return rb_insert(new, &zk_node_root);
 }
 
 static struct zk_node *zk_tree_search_nolock(const struct node_id *nid)
 {
 	struct zk_node key = { .node.nid = *nid };
 
-	return rb_search(&zk_node_root, &key, rb, zk_node_cmp);
+	return rb_search(&key, &zk_node_root);
 }
 
 static inline struct zk_node *zk_tree_search(const struct node_id *nid)
@@ -379,7 +381,7 @@ static int push_join_response(struct zk_event *ev)
 
 	ev->type = EVENT_ACCEPT;
 	ev->nr_nodes = nr_sd_nodes;
-	rb_for_each_entry(n, &sd_node_root, rb) {
+	rb_for_each_entry(n, &sd_node_root) {
 		memcpy(np++, n, sizeof(struct sd_node));
 	}
 	queue_pos--;
@@ -421,7 +423,7 @@ static inline void zk_tree_add(struct zk_node *node)
 	 * Even node list will be built later, we need this because in master
 	 * transfer case, we need this information to destroy the tree.
 	 */
-	rb_insert(&sd_node_root, &zk->node, rb, node_cmp);
+	rb_insert(&zk->node, &sd_node_root);
 	nr_sd_nodes++;
 out:
 	sd_unlock(&zk_tree_lock);
@@ -430,7 +432,7 @@ out:
 static inline void zk_tree_del(struct zk_node *node)
 {
 	sd_write_lock(&zk_tree_lock);
-	rb_erase(&node->rb, &zk_node_root);
+	rb_erase(node, &zk_node_root);
 	free(node);
 	sd_unlock(&zk_tree_lock);
 }
@@ -438,7 +440,7 @@ static inline void zk_tree_del(struct zk_node *node)
 static inline void zk_tree_destroy(void)
 {
 	sd_write_lock(&zk_tree_lock);
-	rb_destroy(&zk_node_root, struct zk_node, rb);
+	rb_destroy(&zk_node_root);
 	sd_unlock(&zk_tree_lock);
 }
 
@@ -447,9 +449,9 @@ static inline void build_node_list(void)
 	struct zk_node *zk;
 
 	nr_sd_nodes = 0;
-	INIT_RB_ROOT(&sd_node_root);
-	rb_for_each_entry(zk, &zk_node_root, rb) {
-		rb_insert(&sd_node_root, &zk->node, rb, node_cmp);
+	INIT_RB_ROOT(&sd_node_root, node_cmp);
+	rb_for_each_entry(zk, &zk_node_root) {
+		rb_insert(&zk->node, &sd_node_root);
 		nr_sd_nodes++;
 	}
 
@@ -996,10 +998,10 @@ static inline void handle_session_expire(void)
 	/* clean memory states */
 	close(efd);
 	zk_tree_destroy();
-	INIT_RB_ROOT(&zk_node_root);
+	INIT_RB_ROOT(&zk_node_root, zk_node_cmp);
 	INIT_LIST_HEAD(&zk_block_list);
 	nr_sd_nodes = 0;
-	INIT_RB_ROOT(&sd_node_root);
+	INIT_RB_ROOT(&sd_node_root, node_cmp);
 	first_push = true;
 	joined = false;
 
diff --git a/sheep/group.c b/sheep/group.c
index 5e90fd5..718258f 100644
--- a/sheep/group.c
+++ b/sheep/group.c
@@ -20,7 +20,7 @@ struct get_vdis_work {
 	struct work work;
 	DECLARE_BITMAP(vdi_inuse, SD_NR_VDIS);
 	struct sd_node joined;
-	struct rb_root nroot;
+	struct rb_node_root nroot;
 };
 
 static pthread_mutex_t wait_vdis_lock = PTHREAD_MUTEX_INITIALIZER;
@@ -31,13 +31,13 @@ static main_thread(struct vnode_info *) current_vnode_info;
 static main_thread(struct list_head *) pending_block_list;
 static main_thread(struct list_head *) pending_notify_list;
 
-static int get_zones_nr_from(struct rb_root *nroot)
+static int get_zones_nr_from(struct rb_node_root *nroot)
 {
 	int nr_zones = 0, j;
 	uint32_t zones[SD_MAX_COPIES];
 	struct sd_node *n;
 
-	rb_for_each_entry(n, nroot, rb) {
+	rb_for_each_entry(n, nroot) {
 		/*
 		 * Only count zones that actually store data, pure gateways
 		 * don't contribute to the redundancy level.
@@ -93,21 +93,21 @@ void put_vnode_info(struct vnode_info *vnode_info)
 {
 	if (vnode_info) {
 		if (refcount_dec(&vnode_info->refcnt) == 0) {
-			rb_destroy(&vnode_info->vroot, struct sd_vnode, rb);
-			rb_destroy(&vnode_info->nroot, struct sd_node, rb);
+			rb_destroy(&vnode_info->vroot);
+			rb_destroy(&vnode_info->nroot);
 			free(vnode_info);
 		}
 	}
 }
 
-static void recalculate_vnodes(struct rb_root *nroot)
+static void recalculate_vnodes(struct rb_node_root *nroot)
 {
 	int nr_non_gateway_nodes = 0;
 	uint64_t avg_size = 0;
 	struct sd_node *n;
 	float factor;
 
-	rb_for_each_entry(n, nroot, rb) {
+	rb_for_each_entry(n, nroot) {
 		if (n->space) {
 			avg_size += n->space;
 			nr_non_gateway_nodes++;
@@ -119,7 +119,7 @@ static void recalculate_vnodes(struct rb_root *nroot)
 
 	avg_size /= nr_non_gateway_nodes;
 
-	rb_for_each_entry(n, nroot, rb) {
+	rb_for_each_entry(n, nroot) {
 		factor = (float)n->space / (float)avg_size;
 		n->nr_vnodes = rintf(SD_DEFAULT_VNODES * factor);
 		sd_debug("node %s has %d vnodes, free space %" PRIu64,
@@ -127,19 +127,19 @@ static void recalculate_vnodes(struct rb_root *nroot)
 	}
 }
 
-struct vnode_info *alloc_vnode_info(const struct rb_root *nroot)
+struct vnode_info *alloc_vnode_info(const struct rb_node_root *nroot)
 {
 	struct vnode_info *vnode_info;
 	struct sd_node *n;
 
 	vnode_info = xzalloc(sizeof(*vnode_info));
 
-	INIT_RB_ROOT(&vnode_info->vroot);
-	INIT_RB_ROOT(&vnode_info->nroot);
-	rb_for_each_entry(n, nroot, rb) {
+	INIT_RB_ROOT(&vnode_info->vroot, vnode_cmp);
+	INIT_RB_ROOT(&vnode_info->nroot, node_cmp);
+	rb_for_each_entry(n, nroot) {
 		struct sd_node *new = xmalloc(sizeof(*new));
 		*new = *n;
-		if (unlikely(rb_insert(&vnode_info->nroot, new, rb, node_cmp)))
+		if (unlikely(rb_insert(new, &vnode_info->nroot)))
 			panic("node hash collision");
 		vnode_info->nr_nodes++;
 	}
@@ -156,7 +156,7 @@ struct vnode_info *get_vnode_info_epoch(uint32_t epoch,
 					struct vnode_info *cur_vinfo)
 {
 	struct sd_node nodes[SD_MAX_NODES];
-	struct rb_root nroot = RB_ROOT;
+	struct rb_node_root nroot = RB_ROOT_INITIALIZER(node_cmp);
 	int nr_nodes;
 
 	nr_nodes = epoch_log_read(epoch, nodes, sizeof(nodes));
@@ -167,7 +167,7 @@ struct vnode_info *get_vnode_info_epoch(uint32_t epoch,
 			return NULL;
 	}
 	for (int i = 0; i < nr_nodes; i++)
-		rb_insert(&nroot, &nodes[i], rb, node_cmp);
+		rb_insert(&nodes[i], &nroot);
 
 	return alloc_vnode_info(&nroot);
 }
@@ -339,7 +339,7 @@ int epoch_log_read_remote(uint32_t epoch, struct sd_node *nodes, int len,
 	const struct sd_node *node;
 	int ret;
 
-	rb_for_each_entry(node, &vinfo->nroot, rb) {
+	rb_for_each_entry(node, &vinfo->nroot) {
 		struct sd_req hdr;
 		struct sd_rsp *rsp = (struct sd_rsp *)&hdr;
 		int nodes_len;
@@ -392,13 +392,13 @@ static bool cluster_ctime_check(const struct cluster_info *cinfo)
  */
 static bool enough_nodes_gathered(struct cluster_info *cinfo,
 				  const struct sd_node *joining,
-				  const struct rb_root *nroot,
+				  const struct rb_node_root *nroot,
 				  size_t nr_nodes)
 {
 	for (int i = 0; i < cinfo->nr_nodes; i++) {
 		const struct sd_node *key = cinfo->nodes + i, *n;
 
-		n = rb_search(nroot, key, rb, node_cmp);
+		n = rb_search(key, nroot);
 		if (n == NULL && !node_eq(key, joining)) {
 			sd_debug("%s doesn't join yet", node_to_str(key));
 			return false;
@@ -423,7 +423,7 @@ static void cluster_info_copy(struct cluster_info *dst,
 }
 
 static enum sd_status cluster_wait_check(const struct sd_node *joining,
-					 const struct rb_root *nroot,
+					 const struct rb_node_root *nroot,
 					 size_t nr_nodes,
 					 struct cluster_info *cinfo)
 {
@@ -497,7 +497,7 @@ static void do_get_vdis(struct work *work)
 		return;
 	}
 
-	rb_for_each_entry(n, &w->nroot, rb) {
+	rb_for_each_entry(n, &w->nroot) {
 		/* We should not fetch vdi_bitmap and copy list from myself */
 		if (node_is_local(n))
 			continue;
@@ -529,7 +529,7 @@ static void get_vdis_done(struct work *work)
 	pthread_cond_broadcast(&wait_vdis_cond);
 	pthread_mutex_unlock(&wait_vdis_lock);
 
-	rb_destroy(&w->nroot, struct sd_node, rb);
+	rb_destroy(&w->nroot);
 	free(w);
 }
 
@@ -551,14 +551,14 @@ int inc_and_log_epoch(void)
 }
 
 static struct vnode_info *alloc_old_vnode_info(const struct sd_node *joined,
-					       const struct rb_root *nroot)
+					       const struct rb_node_root *nroot)
 {
-	struct rb_root old_root = RB_ROOT;
+	struct rb_node_root old_root = RB_ROOT_INITIALIZER(node_cmp);
 	struct sd_node *n;
 	struct vnode_info *old;
 
 	/* exclude the newly added one */
-	rb_for_each_entry(n, nroot, rb) {
+	rb_for_each_entry(n, nroot) {
 		struct sd_node *new = xmalloc(sizeof(*new));
 
 		*new = *n;
@@ -566,12 +566,12 @@ static struct vnode_info *alloc_old_vnode_info(const struct sd_node *joined,
 			free(new);
 			continue;
 		}
-		if (rb_insert(&old_root, new, rb, node_cmp))
+		if (rb_insert(new, &old_root))
 			panic("node hash collision");
 	}
 
 	old = alloc_vnode_info(&old_root);
-	rb_destroy(&old_root, struct sd_node, rb);
+	rb_destroy(&old_root);
 	return old;
 }
 
@@ -604,14 +604,15 @@ static void setup_backend_store(const struct cluster_info *cinfo)
 	}
 }
 
-static void get_vdis(const struct rb_root *nroot, const struct sd_node *joined)
+static void get_vdis(const struct rb_node_root *nroot,
+		     const struct sd_node *joined)
 {
 	struct get_vdis_work *w;
 
 	w = xmalloc(sizeof(*w));
 	w->joined = *joined;
-	INIT_RB_ROOT(&w->nroot);
-	rb_copy(nroot, struct sd_node, rb, &w->nroot, node_cmp);
+	INIT_RB_ROOT(&w->nroot, node_cmp);
+	rb_copy(nroot, &w->nroot);
 	refcount_inc(&nr_get_vdis_works);
 
 	w->work.fn = do_get_vdis;
@@ -633,7 +634,7 @@ void wait_get_vdis_done(void)
 
 static void update_cluster_info(const struct cluster_info *cinfo,
 				const struct sd_node *joined,
-				const struct rb_root *nroot,
+				const struct rb_node_root *nroot,
 				size_t nr_nodes)
 {
 	struct vnode_info *old_vnode_info;
@@ -735,7 +736,7 @@ main_fn void sd_notify_handler(const struct sd_node *sender, void *data,
  * cluster must call this function and succeed in accept of the joining node.
  */
 main_fn bool sd_join_handler(const struct sd_node *joining,
-			     const struct rb_root *nroot, size_t nr_nodes,
+			     const struct rb_node_root *nroot, size_t nr_nodes,
 			     void *opaque)
 {
 	struct cluster_info *cinfo = opaque;
@@ -882,8 +883,8 @@ static bool cluster_join_check(const struct cluster_info *cinfo)
 }
 
 main_fn void sd_accept_handler(const struct sd_node *joined,
-			       const struct rb_root *nroot, size_t nr_nodes,
-			       const void *opaque)
+			       const struct rb_node_root *nroot,
+			       size_t nr_nodes, const void *opaque)
 {
 	const struct cluster_info *cinfo = opaque;
 	struct sd_node *n;
@@ -896,7 +897,7 @@ main_fn void sd_accept_handler(const struct sd_node *joined,
 	cluster_info_copy(&sys->cinfo, cinfo);
 
 	sd_debug("join %s", node_to_str(joined));
-	rb_for_each_entry(n, nroot, rb) {
+	rb_for_each_entry(n, nroot) {
 		sd_debug("%s", node_to_str(n));
 	}
 
@@ -911,14 +912,14 @@ main_fn void sd_accept_handler(const struct sd_node *joined,
 }
 
 main_fn void sd_leave_handler(const struct sd_node *left,
-			      const struct rb_root *nroot, size_t nr_nodes)
+			      const struct rb_node_root *nroot, size_t nr_nodes)
 {
 	struct vnode_info *old_vnode_info;
 	struct sd_node *n;
 	int ret;
 
 	sd_debug("leave %s", node_to_str(left));
-	rb_for_each_entry(n, nroot, rb) {
+	rb_for_each_entry(n, nroot) {
 		sd_debug("%s", node_to_str(n));
 	}
 
@@ -947,7 +948,7 @@ main_fn void sd_leave_handler(const struct sd_node *left,
 static void update_node_size(struct sd_node *node)
 {
 	struct vnode_info *cur_vinfo = main_thread_get(current_vnode_info);
-	struct sd_node *n = rb_search(&cur_vinfo->nroot, node, rb, node_cmp);
+	struct sd_node *n = rb_search(node, &cur_vinfo->nroot);
 
 	if (unlikely(!n))
 		panic("can't find %s", node_to_str(node));
diff --git a/sheep/md.c b/sheep/md.c
index 9fc1b6e..94ac4d8 100644
--- a/sheep/md.c
+++ b/sheep/md.c
@@ -29,17 +29,27 @@ struct vdisk {
 	uint64_t hash;
 };
 
+static int disk_cmp(const struct disk *d1, const struct disk *d2)
+{
+	return strcmp(d1->path, d2->path);
+}
+
+static int vdisk_cmp(const struct vdisk *d1, const struct vdisk *d2)
+{
+	return intcmp(d1->hash, d2->hash);
+}
+
 struct md {
-	struct rb_root vroot;
-	struct rb_root root;
+	RB_ROOT(, struct vdisk, rb) vroot;
+	RB_ROOT(, struct disk, rb) root;
 	struct sd_lock lock;
 	uint64_t space;
 	uint32_t nr_disks;
 };
 
 static struct md md = {
-	.vroot = RB_ROOT,
-	.root = RB_ROOT,
+	.vroot = RB_ROOT_INITIALIZER(vdisk_cmp),
+	.root = RB_ROOT_INITIALIZER(disk_cmp),
 	.lock = SD_LOCK_INITIALIZER,
 };
 
@@ -59,19 +69,9 @@ static inline int vdisk_number(const struct disk *disk)
 	return DIV_ROUND_UP(disk->space, MD_VDISK_SIZE);
 }
 
-static int disk_cmp(const struct disk *d1, const struct disk *d2)
-{
-	return strcmp(d1->path, d2->path);
-}
-
-static int vdisk_cmp(const struct vdisk *d1, const struct vdisk *d2)
-{
-	return intcmp(d1->hash, d2->hash);
-}
-
 static struct vdisk *vdisk_insert(struct vdisk *new)
 {
-	return rb_insert(&md.vroot, new, rb, vdisk_cmp);
+	return rb_insert(new, &md.vroot);
 }
 
 /* If v1_hash < hval <= v2_hash, then oid is resident in v2 */
@@ -79,7 +79,7 @@ static struct vdisk *hval_to_vdisk(uint64_t hval)
 {
 	struct vdisk dummy = { .hash = hval };
 
-	return rb_nsearch(&md.vroot, &dummy, rb, vdisk_cmp);
+	return rb_nsearch(&dummy, &md.vroot);
 }
 
 static struct vdisk *oid_to_vdisk(uint64_t oid)
@@ -105,7 +105,7 @@ static void create_vdisks(struct disk *disk)
 
 static inline void vdisk_free(struct vdisk *v)
 {
-	rb_erase(&v->rb, &md.vroot);
+	rb_erase(v, &md.vroot);
 	free(v);
 }
 
@@ -139,7 +139,7 @@ static struct disk *path_to_disk(const char *path)
 	pstrcpy(key.path, sizeof(key.path), path);
 	trim_last_slash(key.path);
 
-	return rb_search(&md.root, &key, rb, disk_cmp);
+	return rb_search(&key, &md.root);
 }
 
 static int get_total_object_size(uint64_t oid, const char *wd, uint32_t epoch,
@@ -302,7 +302,7 @@ bool md_add_disk(const char *path, bool purge)
 	}
 
 	create_vdisks(new);
-	rb_insert(&md.root, new, rb, disk_cmp);
+	rb_insert(new, &md.root);
 	md.space += new->space;
 	md.nr_disks++;
 
@@ -314,7 +314,7 @@ bool md_add_disk(const char *path, bool purge)
 static inline void md_remove_disk(struct disk *disk)
 {
 	sd_info("%s from multi-disk array", disk->path);
-	rb_erase(&disk->rb, &md.root);
+	rb_erase(disk, &md.root);
 	md.nr_disks--;
 	remove_vdisks(disk);
 	free(disk);
@@ -355,7 +355,7 @@ int for_each_object_in_wd(int (*func)(uint64_t oid, const char *path,
 	const struct disk *disk;
 
 	sd_read_lock(&md.lock);
-	rb_for_each_entry(disk, &md.root, rb) {
+	rb_for_each_entry(disk, &md.root) {
 		ret = for_each_object_in_path(disk->path, func, cleanup, arg);
 		if (ret != SD_RES_SUCCESS)
 			break;
@@ -373,7 +373,7 @@ int for_each_object_in_stale(int (*func)(uint64_t oid, const char *path,
 	const struct disk *disk;
 
 	sd_read_lock(&md.lock);
-	rb_for_each_entry(disk, &md.root, rb) {
+	rb_for_each_entry(disk, &md.root) {
 		snprintf(path, sizeof(path), "%s/.stale", disk->path);
 		ret = for_each_object_in_path(path, func, false, arg);
 		if (ret != SD_RES_SUCCESS)
@@ -390,7 +390,7 @@ int for_each_obj_path(int (*func)(const char *path))
 	const struct disk *disk;
 
 	sd_read_lock(&md.lock);
-	rb_for_each_entry(disk, &md.root, rb) {
+	rb_for_each_entry(disk, &md.root) {
 		ret = func(disk->path);
 		if (ret != SD_RES_SUCCESS)
 			break;
@@ -547,7 +547,7 @@ static int scan_wd(uint64_t oid, uint32_t epoch)
 	const struct disk *disk;
 
 	sd_read_lock(&md.lock);
-	rb_for_each_entry(disk, &md.root, rb) {
+	rb_for_each_entry(disk, &md.root) {
 		ret = md_check_and_move(oid, epoch, disk->path);
 		if (ret == SD_RES_SUCCESS)
 			break;
@@ -597,7 +597,7 @@ uint32_t md_get_info(struct sd_md_info *info)
 
 	memset(info, 0, ret);
 	sd_read_lock(&md.lock);
-	rb_for_each_entry(disk, &md.root, rb) {
+	rb_for_each_entry(disk, &md.root) {
 		info->disk[i].idx = i;
 		pstrcpy(info->disk[i].path, PATH_MAX, disk->path);
 		/* FIXME: better handling failure case. */
@@ -669,7 +669,7 @@ uint64_t md_get_size(uint64_t *used)
 
 	*used = 0;
 	sd_read_lock(&md.lock);
-	rb_for_each_entry(disk, &md.root, rb) {
+	rb_for_each_entry(disk, &md.root) {
 		fsize += get_path_free_size(disk->path, used);
 	}
 	sd_unlock(&md.lock);
diff --git a/sheep/object_cache.c b/sheep/object_cache.c
index bd714f3..4b554cf 100644
--- a/sheep/object_cache.c
+++ b/sheep/object_cache.c
@@ -56,7 +56,8 @@ struct object_cache {
 	uint32_t dirty_count; /* How many dirty object in this cache */
 	uint32_t total_count; /* Count of objects include dirty and clean */
 	struct hlist_node hash; /* VDI is linked to the global hash lists */
-	struct rb_root lru_tree; /* For faster object search */
+	/* For faster object search */
+	RB_ROOT(rb_lru_root, struct object_cache_entry, node) lru_tree;
 	struct list_head lru_head; /* Per VDI LRU list for reclaimer */
 	struct list_head dirty_head; /* Dirty objects linked to this list */
 	int push_efd; /* Used to synchronize between pusher and push threads */
@@ -203,17 +204,17 @@ static inline void unlock_entry(struct object_cache_entry *entry)
 }
 
 static struct object_cache_entry *
-lru_tree_insert(struct rb_root *root, struct object_cache_entry *new)
+lru_tree_insert(struct rb_lru_root *root, struct object_cache_entry *new)
 {
-	return rb_insert(root, new, node, object_cache_cmp);
+	return rb_insert(new, root);
 }
 
-static struct object_cache_entry *lru_tree_search(struct rb_root *root,
+static struct object_cache_entry *lru_tree_search(struct rb_lru_root *root,
 						  uint32_t idx)
 {
 	struct object_cache_entry key = { .idx = idx };
 
-	return rb_search(root, &key, node, object_cache_cmp);
+	return rb_search(&key, root);
 }
 
 static void do_background_push(struct work *work)
@@ -270,7 +271,7 @@ free_cache_entry(struct object_cache_entry *entry)
 {
 	struct object_cache *oc = entry->oc;
 
-	rb_erase(&entry->node, &oc->lru_tree);
+	rb_erase(entry, &oc->lru_tree);
 	list_del(&entry->lru_list);
 	oc->total_count--;
 	if (list_linked(&entry->dirty_list))
@@ -624,7 +625,7 @@ not_found:
 	if (create) {
 		cache = xzalloc(sizeof(*cache));
 		cache->vid = vid;
-		INIT_RB_ROOT(&cache->lru_tree);
+		INIT_RB_ROOT(&cache->lru_tree, object_cache_cmp);
 		create_dir_for(vid);
 		cache->push_efd = eventfd(0, 0);
 
diff --git a/sheep/object_list_cache.c b/sheep/object_list_cache.c
index caba3ce..78bb852 100644
--- a/sheep/object_list_cache.c
+++ b/sheep/object_list_cache.c
@@ -18,12 +18,18 @@ struct objlist_cache_entry {
 	struct rb_node node;
 };
 
+static int objlist_cache_cmp(const struct objlist_cache_entry *a,
+			     const struct objlist_cache_entry *b)
+{
+	return intcmp(a->oid, b->oid);
+}
+
 struct objlist_cache {
 	int tree_version;
 	int buf_version;
 	int cache_size;
 	uint64_t *buf;
-	struct rb_root root;
+	RB_ROOT(rb_objlist_root, struct objlist_cache_entry, node) root;
 	struct sd_lock lock;
 };
 
@@ -34,31 +40,26 @@ struct objlist_deletion_work {
 
 static struct objlist_cache obj_list_cache = {
 	.tree_version	= 1,
-	.root		= RB_ROOT,
+	.root		= RB_ROOT_INITIALIZER(objlist_cache_cmp),
 	.lock		= SD_LOCK_INITIALIZER,
 };
 
-static int objlist_cache_cmp(const struct objlist_cache_entry *a,
-			     const struct objlist_cache_entry *b)
-{
-	return intcmp(a->oid, b->oid);
-}
-
-static struct objlist_cache_entry *objlist_cache_rb_insert(struct rb_root *root,
-		struct objlist_cache_entry *new)
+static struct objlist_cache_entry *objlist_cache_rb_insert(
+	struct rb_objlist_root *root,
+	struct objlist_cache_entry *new)
 {
-	return rb_insert(root, new, node, objlist_cache_cmp);
+	return rb_insert(new, root);
 }
 
-static int objlist_cache_rb_remove(struct rb_root *root, uint64_t oid)
+static int objlist_cache_rb_remove(struct rb_objlist_root *root, uint64_t oid)
 {
-	struct objlist_cache_entry *entry,  key = { .oid = oid  };
+	struct objlist_cache_entry *entry,  key = { .oid = oid };
 
-	entry = rb_search(root, &key, node, objlist_cache_cmp);
+	entry = rb_search(&key, root);
 	if (!entry)
 		return -1;
 
-	rb_erase(&entry->node, root);
+	rb_erase(entry, root);
 	free(entry);
 
 	return 0;
@@ -115,7 +116,7 @@ int get_obj_list(const struct sd_req *hdr, struct sd_rsp *rsp, void *data)
 	obj_list_cache.buf = xrealloc(obj_list_cache.buf,
 				obj_list_cache.cache_size * sizeof(uint64_t));
 
-	rb_for_each_entry(entry, &obj_list_cache.root, node) {
+	rb_for_each_entry(entry, &obj_list_cache.root) {
 		obj_list_cache.buf[nr++] = entry->oid;
 	}
 
@@ -153,7 +154,7 @@ static void objlist_deletion_work(struct work *work)
 	}
 
 	sd_write_lock(&obj_list_cache.lock);
-	rb_for_each_entry(entry, &obj_list_cache.root, node) {
+	rb_for_each_entry(entry, &obj_list_cache.root) {
 		entry_vid = oid_to_vid(entry->oid);
 		if (entry_vid != vid)
 			continue;
@@ -163,7 +164,7 @@ static void objlist_deletion_work(struct work *work)
 			continue;
 
 		sd_debug("delete object entry %" PRIx64, entry->oid);
-		rb_erase(&entry->node, &obj_list_cache.root);
+		rb_erase(entry, &obj_list_cache.root);
 		free(entry);
 	}
 	sd_unlock(&obj_list_cache.lock);
diff --git a/sheep/ops.c b/sheep/ops.c
index 7fdb351..05ba009 100644
--- a/sheep/ops.c
+++ b/sheep/ops.c
@@ -534,7 +534,7 @@ static int cluster_force_recover_main(const struct sd_req *req,
 	int ret = SD_RES_SUCCESS;
 	struct sd_node *nodes = data;
 	size_t nr_nodes = rsp->data_length / sizeof(*nodes);
-	struct rb_root nroot = RB_ROOT;
+	struct rb_node_root nroot = RB_ROOT_INITIALIZER(node_cmp);
 
 	if (rsp->epoch != sys->cinfo.epoch) {
 		sd_err("epoch was incremented while cluster_force_recover");
@@ -554,7 +554,7 @@ static int cluster_force_recover_main(const struct sd_req *req,
 	sys->cinfo.status = SD_STATUS_OK;
 
 	for (int i = 0; i < nr_nodes; i++)
-		rb_insert(&nroot, &nodes[i], rb, node_cmp);
+		rb_insert(&nodes[i], &nroot);
 
 	vnode_info = get_vnode_info();
 	old_vnode_info = alloc_vnode_info(&nroot);
@@ -657,8 +657,7 @@ static int cluster_recovery_completion(const struct sd_req *req,
 
 	if (vnode_info->nr_nodes == nr_recovereds) {
 		for (i = 0; i < nr_recovereds; ++i) {
-			if (!rb_search(&vnode_info->nroot, &recovereds[i],
-				       rb, node_cmp))
+			if (!rb_search(&recovereds[i], &vnode_info->nroot))
 				break;
 		}
 		if (i == nr_recovereds) {
diff --git a/sheep/recovery.c b/sheep/recovery.c
index 0df3a5a..36c6251 100644
--- a/sheep/recovery.c
+++ b/sheep/recovery.c
@@ -164,7 +164,7 @@ static int recover_object_from(struct recovery_obj_work *row,
 static bool invalid_node(const struct sd_node *n, struct vnode_info *info)
 {
 
-	if (rb_search(&info->nroot, n, rb, node_cmp))
+	if (rb_search(n, &info->nroot))
 		return false;
 	return true;
 }
diff --git a/sheep/sheep_priv.h b/sheep/sheep_priv.h
index 588a61c..5a178ce 100644
--- a/sheep/sheep_priv.h
+++ b/sheep/sheep_priv.h
@@ -299,7 +299,7 @@ int local_get_node_list(const struct sd_req *req, struct sd_rsp *rsp,
 struct vnode_info *grab_vnode_info(struct vnode_info *vnode_info);
 struct vnode_info *get_vnode_info(void);
 void put_vnode_info(struct vnode_info *vinfo);
-struct vnode_info *alloc_vnode_info(const struct rb_root *);
+struct vnode_info *alloc_vnode_info(const struct rb_node_root *);
 struct vnode_info *get_vnode_info_epoch(uint32_t epoch,
 					struct vnode_info *cur_vinfo);
 void wait_get_vdis_done(void);
diff --git a/sheep/vdi.c b/sheep/vdi.c
index e46e3e7..64f4f9c 100644
--- a/sheep/vdi.c
+++ b/sheep/vdi.c
@@ -18,27 +18,28 @@ struct vdi_state_entry {
 	struct rb_node node;
 };
 
-static struct rb_root vdi_state_root = RB_ROOT;
-static struct sd_lock vdi_state_lock = SD_LOCK_INITIALIZER;
-
 static int vdi_state_cmp(const struct vdi_state_entry *a,
 			 const struct vdi_state_entry *b)
 {
 	return intcmp(a->vid, b->vid);
 }
 
-static struct vdi_state_entry *vdi_state_search(struct rb_root *root,
+static RB_ROOT(rb_vdi_state_root, struct vdi_state_entry, node) vdi_state_root =
+	RB_ROOT_INITIALIZER(vdi_state_cmp);
+static struct sd_lock vdi_state_lock = SD_LOCK_INITIALIZER;
+
+static struct vdi_state_entry *vdi_state_search(struct rb_vdi_state_root *root,
 						uint32_t vid)
 {
 	struct vdi_state_entry key = { .vid = vid };
 
-	return rb_search(root, &key, node, vdi_state_cmp);
+	return rb_search(&key, root);
 }
 
-static struct vdi_state_entry *vdi_state_insert(struct rb_root *root,
+static struct vdi_state_entry *vdi_state_insert(struct rb_vdi_state_root *root,
 						struct vdi_state_entry *new)
 {
-	return rb_insert(root, new, node, vdi_state_cmp);
+	return rb_insert(new, root);
 }
 
 static bool vid_is_snapshot(uint32_t vid)
@@ -132,7 +133,7 @@ int fill_vdi_state_list(void *data)
 	struct vdi_state_entry *entry;
 
 	sd_read_lock(&vdi_state_lock);
-	rb_for_each_entry(entry, &vdi_state_root, node) {
+	rb_for_each_entry(entry, &vdi_state_root) {
 		memset(vs, 0, sizeof(*vs));
 		vs->vid = entry->vid;
 		vs->nr_copies = entry->nr_copies;
@@ -1105,7 +1106,7 @@ out:
 void clean_vdi_state(void)
 {
 	sd_write_lock(&vdi_state_lock);
-	rb_destroy(&vdi_state_root, struct vdi_state_entry, node);
-	INIT_RB_ROOT(&vdi_state_root);
+	rb_destroy(&vdi_state_root);
+	INIT_RB_ROOT(&vdi_state_root, vdi_state_cmp);
 	sd_unlock(&vdi_state_lock);
 }
diff --git a/sheepfs/volume.c b/sheepfs/volume.c
index ca8925a..a08cde6 100644
--- a/sheepfs/volume.c
+++ b/sheepfs/volume.c
@@ -61,24 +61,25 @@ struct vdi_inode {
 	unsigned socket_poll_adder;
 };
 
-static struct rb_root vdi_inode_tree = RB_ROOT;
-static struct sd_lock vdi_inode_tree_lock = SD_LOCK_INITIALIZER;
-
 static int vdi_inode_cmp(const struct vdi_inode *a, const struct vdi_inode *b)
 {
 	return intcmp(a->vid, b->vid);
 }
 
+static RB_ROOT(, struct vdi_inode, rb) vdi_inode_tree =
+	RB_ROOT_INITIALIZER(vdi_inode_cmp);
+static struct sd_lock vdi_inode_tree_lock = SD_LOCK_INITIALIZER;
+
 static struct vdi_inode *vdi_inode_tree_insert(struct vdi_inode *new)
 {
-	return rb_insert(&vdi_inode_tree, new, rb, vdi_inode_cmp);
+	return rb_insert(new, &vdi_inode_tree);
 }
 
 static struct vdi_inode *vdi_inode_tree_search(uint32_t vid)
 {
 	struct vdi_inode key = { .vid = vid };
 
-	return rb_search(&vdi_inode_tree, &key, rb, vdi_inode_cmp);
+	return rb_search(&key, &vdi_inode_tree);
 }
 
 int create_volume_layout(void)
@@ -349,7 +350,7 @@ int reset_socket_pool(void)
 	int ret = 0;
 
 	sd_read_lock(&vdi_inode_tree_lock);
-	rb_for_each_entry(vdi, &vdi_inode_tree, rb) {
+	rb_for_each_entry(vdi, &vdi_inode_tree) {
 		destroy_socket_pool(vdi->socket_pool, SOCKET_POOL_SIZE);
 		if (setup_socket_pool(vdi->socket_pool,
 			SOCKET_POOL_SIZE) < 0) {
@@ -400,7 +401,7 @@ static int init_vdi_info(const char *entry, uint32_t *vid, size_t *size)
 		goto err;
 	if (volume_rw_object(inode_buf, vid_to_vdi_oid(*vid), SD_INODE_SIZE,
 			     0, VOLUME_READ) < 0) {
-		rb_erase(&inode->rb, &vdi_inode_tree);
+		rb_erase(inode, &vdi_inode_tree);
 		sheepfs_pr("failed to read inode for %"PRIx32"\n", *vid);
 		goto err;
 	}
@@ -502,7 +503,7 @@ int volume_remove_entry(const char *entry)
 	destroy_socket_pool(vdi->socket_pool, SOCKET_POOL_SIZE);
 
 	sd_write_lock(&vdi_inode_tree_lock);
-	rb_erase(&vdi->rb, &vdi_inode_tree);
+	rb_erase(vdi, &vdi_inode_tree);
 	sd_unlock(&vdi_inode_tree_lock);
 
 	free(vdi->inode);
diff --git a/tests/unit/dog/mock_dog.c b/tests/unit/dog/mock_dog.c
index 80eeb0e..e11d86a 100644
--- a/tests/unit/dog/mock_dog.c
+++ b/tests/unit/dog/mock_dog.c
@@ -22,8 +22,8 @@ struct node_id sd_nid = {
 };
 bool highlight = true;
 bool raw_output;
-struct rb_root sd_vroot = RB_ROOT;
-struct rb_root sd_nroot = RB_ROOT;
+struct rb_vnode_root sd_vroot = RB_ROOT_INITIALIZER(vnode_cmp);
+struct rb_node_root sd_nroot = RB_ROOT_INITIALIZER(node_cmp);
 
 MOCK_METHOD(update_node_list, int, 0, int max_nodes)
 MOCK_VOID_METHOD(subcommand_usage, char *cmd, char *subcmd, int status)
diff --git a/tests/unit/mock/mock.c b/tests/unit/mock/mock.c
index 373c02a..9c2660d 100644
--- a/tests/unit/mock/mock.c
+++ b/tests/unit/mock/mock.c
@@ -17,13 +17,13 @@
 
 #include "mock.h"
 
-struct rb_root mock_methods = RB_ROOT;
+struct rb_mock_tree mock_methods = RB_ROOT_INITIALIZER(mock_cmp);
 
 static struct mock_method *find_method(const char *name)
 {
 	struct mock_method key = { .name = name };
 
-	return rb_search(&mock_methods, &key, rb, mock_cmp);
+	return rb_search(&key, &mock_methods);
 }
 
 int __method_nr_call(const char *name)
@@ -40,6 +40,6 @@ int __method_nr_call(const char *name)
 void __method_reset_all(void)
 {
 	struct mock_method *method;
-	rb_for_each_entry(method, &mock_methods, rb)
+	rb_for_each_entry(method, &mock_methods)
 		method->nr_call = 0;
 }
diff --git a/tests/unit/mock/mock.h b/tests/unit/mock/mock.h
index c2d6f5b..42f4a89 100644
--- a/tests/unit/mock/mock.h
+++ b/tests/unit/mock/mock.h
@@ -30,11 +30,13 @@ static inline int mock_cmp(const struct mock_method *m1,
 	return strcmp(m1->name, m2->name);
 }
 
-extern struct rb_root mock_methods;
+RB_ROOT(rb_mock_tree, struct mock_method, rb);
+
+extern struct rb_mock_tree mock_methods;
 #define method_register(m)						\
 	static void __attribute__((constructor)) regist_##m(void)	\
 	{								\
-		rb_insert(&mock_methods, &m, rb, mock_cmp);		\
+		rb_insert(&m, &mock_methods);				\
 	}
 
 #define MOCK_VOID_METHOD(m, ...)			\
diff --git a/tests/unit/sheep/mock_group.c b/tests/unit/sheep/mock_group.c
index 27e9b6e..556183a 100644
--- a/tests/unit/sheep/mock_group.c
+++ b/tests/unit/sheep/mock_group.c
@@ -17,13 +17,13 @@
 #include "cluster.h"
 
 MOCK_VOID_METHOD(sd_accept_handler, const struct sd_node *joined,
-		 const struct rb_root *nroot, size_t nr_nodes,
+		 const struct rb_node_root *nroot, size_t nr_nodes,
 		 const void *opaque)
 MOCK_METHOD(sd_join_handler, bool, true, const struct sd_node *joining,
-	    const struct rb_root *nroot, size_t nr_nodes,
+	    const struct rb_node_root *nroot, size_t nr_nodes,
 	    void *opaque)
 MOCK_VOID_METHOD(sd_leave_handler, const struct sd_node *left,
-		 const struct rb_root *nroot, size_t nr_nodes)
+		 const struct rb_node_root *nroot, size_t nr_nodes)
 MOCK_VOID_METHOD(sd_notify_handler, const struct sd_node *sender, void *msg,
 		 size_t msg_len)
 MOCK_METHOD(sd_block_handler, bool, true, const struct sd_node *sender)
diff --git a/tests/unit/sheep/test_hash.c b/tests/unit/sheep/test_hash.c
index 26386b9..02aac37 100644
--- a/tests/unit/sheep/test_hash.c
+++ b/tests/unit/sheep/test_hash.c
@@ -230,12 +230,13 @@ static void node5_setup(void)
 	gen_nodes = gen_many_nodes_some_vnodes;
 }
 
-static size_t get_vnodes_array(struct rb_root *vroot, struct sd_vnode *vnodes)
+static size_t get_vnodes_array(const struct rb_vnode_root *vroot,
+			       struct sd_vnode *vnodes)
 {
 	struct sd_vnode *vnode;
 	size_t nr = 0;
 
-	rb_for_each_entry(vnode, vroot, rb) {
+	rb_for_each_entry(vnode, vroot) {
 		nr++;
 		*vnodes++ = *vnode;
 	}
@@ -251,11 +252,11 @@ START_TEST(test_nodes_update)
 	struct sd_node nodes[DATA_SIZE];
 	struct sd_vnode vnodes[DATA_SIZE];
 	struct sd_vnode vnodes_after[DATA_SIZE];
-	struct rb_root vroot;
+	struct rb_vnode_root vroot;
 
 	gen_nodes(nodes, 0);
 
-	INIT_RB_ROOT(&vroot);
+	INIT_RB_ROOT(&vroot, vnode_cmp);
 	node_to_vnodes(nodes, &vroot);
 	nr_vnodes = get_vnodes_array(&vroot, vnodes);
 	/* 1 node join */
@@ -264,7 +265,7 @@ START_TEST(test_nodes_update)
 	ck_assert(is_subset(vnodes_after, nr_vnodes_after, vnodes,
 			    nr_vnodes, vnode_cmp));
 
-	INIT_RB_ROOT(&vroot);
+	INIT_RB_ROOT(&vroot, vnode_cmp);
 	for (int i = 0; i < 100; i++)
 		node_to_vnodes(nodes + i, &vroot);
 	nr_vnodes = get_vnodes_array(&vroot, vnodes);
@@ -280,30 +281,30 @@ START_TEST(test_nodes_update)
 	ck_assert(is_subset(vnodes_after, nr_vnodes_after, vnodes,
 			    nr_vnodes, vnode_cmp));
 
-	INIT_RB_ROOT(&vroot);
+	INIT_RB_ROOT(&vroot, vnode_cmp);
 	node_to_vnodes(nodes, &vroot);
 	node_to_vnodes(nodes + 1, &vroot);
 	nr_vnodes = get_vnodes_array(&vroot, vnodes);
 	/* 1 node leave */
-	INIT_RB_ROOT(&vroot);
+	INIT_RB_ROOT(&vroot, vnode_cmp);
 	node_to_vnodes(nodes, &vroot);
 	nr_vnodes_after = get_vnodes_array(&vroot, vnodes_after);
 	ck_assert(is_subset(vnodes, nr_vnodes, vnodes_after,
 			    nr_vnodes_after, vnode_cmp));
 
-	INIT_RB_ROOT(&vroot);
+	INIT_RB_ROOT(&vroot, vnode_cmp);
 	for (int i = 0; i < 200; i++)
 		node_to_vnodes(nodes + i, &vroot);
 	nr_vnodes = get_vnodes_array(&vroot, vnodes);
 	/* 1 node leave */
-	INIT_RB_ROOT(&vroot);
+	INIT_RB_ROOT(&vroot, vnode_cmp);
 	for (int i = 0; i < 199; i++)
 		node_to_vnodes(nodes + i, &vroot);
 	nr_vnodes_after = get_vnodes_array(&vroot, vnodes_after);
 	ck_assert(is_subset(vnodes, nr_vnodes, vnodes_after,
 			    nr_vnodes_after, vnode_cmp));
 	/* 100 nodes leave */
-	INIT_RB_ROOT(&vroot);
+	INIT_RB_ROOT(&vroot, vnode_cmp);
 	for (int i = 50; i < 150; i++)
 		node_to_vnodes(nodes + i, &vroot);
 	nr_vnodes_after = get_vnodes_array(&vroot, vnodes_after);
@@ -316,16 +317,16 @@ static void gen_data_from_nodes(double *data, int idx)
 {
 	struct sd_node nodes[DATA_SIZE];
 	struct sd_vnode *vnode;
-	struct rb_root vroot;
+	struct rb_vnode_root vroot;
 	int nr_nodes;
 	double *p = data;
 
 	nr_nodes = gen_nodes(nodes, idx);
-	INIT_RB_ROOT(&vroot);
+	INIT_RB_ROOT(&vroot, vnode_cmp);
 	for (int i = 0; i < nr_nodes; i++)
 		node_to_vnodes(nodes + i, &vroot);
 
-	rb_for_each_entry(vnode, &vroot, rb)
+	rb_for_each_entry(vnode, &vroot)
 		*p++ = vnode->hash;
 
 	ck_assert_int_eq(p - data, DATA_SIZE);
@@ -398,7 +399,7 @@ static size_t get_vdisks_array(struct vdisk *vdisks)
 	struct vdisk *vdisk;
 	size_t nr = 0;
 
-	rb_for_each_entry(vdisk, &md.vroot, rb) {
+	rb_for_each_entry(vdisk, &md.vroot) {
 		nr++;
 		*vdisks++ = *vdisk;
 	}
@@ -416,7 +417,7 @@ START_TEST(test_disks_update)
 
 	gen_disks(disks, 0);
 
-	INIT_RB_ROOT(&md.vroot);
+	INIT_RB_ROOT(&md.vroot, vdisk_cmp);
 	create_vdisks(disks);
 	nr_vdisks = get_vdisks_array(vdisks);
 	/* add 1 disk */
@@ -425,7 +426,7 @@ START_TEST(test_disks_update)
 	ck_assert(is_subset(vdisks_after, nr_vdisks_after, vdisks,
 			    nr_vdisks, vdisk_cmp));
 
-	INIT_RB_ROOT(&md.vroot);
+	INIT_RB_ROOT(&md.vroot, vdisk_cmp);
 	for (int i = 0; i < 30; i++)
 		create_vdisks(disks + i);
 	nr_vdisks = get_vdisks_array(vdisks);
@@ -441,7 +442,7 @@ START_TEST(test_disks_update)
 	ck_assert(is_subset(vdisks_after, nr_vdisks_after, vdisks,
 			    nr_vdisks, vdisk_cmp));
 
-	INIT_RB_ROOT(&md.vroot);
+	INIT_RB_ROOT(&md.vroot, vdisk_cmp);
 	create_vdisks(disks);
 	create_vdisks(disks + 1);
 	nr_vdisks = get_vdisks_array(vdisks);
@@ -451,7 +452,7 @@ START_TEST(test_disks_update)
 	ck_assert(is_subset(vdisks, nr_vdisks, vdisks_after,
 			    nr_vdisks_after, vdisk_cmp));
 
-	INIT_RB_ROOT(&md.vroot);
+	INIT_RB_ROOT(&md.vroot, vdisk_cmp);
 	for (int i = 0; i < 50; i++)
 		create_vdisks(disks + i);
 	nr_vdisks = get_vdisks_array(vdisks);
@@ -479,11 +480,11 @@ static void gen_data_from_disks(double *data, int idx)
 	double *p = data;
 
 	nr_disks = gen_disks(disks, idx);
-	INIT_RB_ROOT(&md.vroot);
+	INIT_RB_ROOT(&md.vroot, vdisk_cmp);
 	for (int i = 0; i < nr_disks; i++)
 		create_vdisks(disks + i);
 
-	rb_for_each_entry(vdisk, &md.vroot, rb)
+	rb_for_each_entry(vdisk, &md.vroot)
 		*p++ = vdisk->hash;
 
 	ck_assert_int_eq(p - data, DATA_SIZE);
-- 
1.8.1.2




More information about the sheepdog mailing list