/* splay.c $Id: splay.c,v 1.1 2006/12/22 04:15:08 dmochiha Exp $ */ #include #include #include "splay.h" static splay * splay_tree (splay *t, void *i, cmpfunc comp) { splay n, *l, *r, *y; if (t == NULL) return t; n.left = n.right = NULL; l = r = &n; for (;;) { if ((*comp)(i, t->item) < 0) { if (t->left == NULL) break; if ((*comp)(i, t->left->item) < 0) { /* rotate */ y = t->left; t->left = y->right; y->right = t; t = y; if (t->left == NULL) break; } r->left = t; /* link right */ r = t; t = t->left; } else if ((*comp)(i, t->item) > 0) { if (t->right == NULL) break; if ((*comp)(i, t->right->item) > 0) { /* rotate */ y = t->right; t->right = y->left; y->left = t; t = y; if (t->right == NULL) break; } l->right = t; /* link left */ l = t; t = t->right; } else break; } l->right = t->left; r->left = t->right; t->left = n.right; t->right = n.left; return t; } splay * splay_insert (splay *t, void *i, cmpfunc comp) { splay *n; if ((n = (splay *)malloc(sizeof(splay))) == NULL) { perror("malloc"); exit(1); } n->item = i; if (t == NULL) { n->left = n->right = NULL; return n; } t = splay_tree (t, i, comp); if ((*comp)(i, t->item) < 0) { n->left = t->left; n->right = t; t->left = NULL; return n; } else if ((*comp)(i, t->item) > 0) { n->right = t->right; n->left = t; t->right = NULL; return n; } else { /* already in the tree */ free(n); return t; } } splay * splay_delete (splay *t, void *i, cmpfunc comp) { splay *x; if (t == NULL) return NULL; t = splay_tree (t, i, comp); if ((*comp)(i, t->item) == 0) { /* found */ if (t->left == NULL) { x = t->right; } else { x = splay_tree (t->left, i, comp); x->right = t->right; } free(t); return x; } /* not found */ return t; } void * splay_find (splay **tp, void *i, cmpfunc comp) /* NULL -> fail, else success */ { if (*tp == NULL) return NULL; *tp = splay_tree (*tp, i, comp); if ((*comp)(i, (*tp)->item) == 0) /* found */ return (*tp)->item; else return NULL; } splay * splay_free (splay *t) { if (t == NULL) return NULL; else { splay_free (t->left); splay_free (t->right); free(t); return NULL; } } void _splay_print (splay *t, prfunc print, int n) { int i; if (t == NULL) return; else { _splay_print (t->left, print, n + 1); for (i = 0; i < n; i++) printf(" "); (*print)(t->item); _splay_print (t->right, print, n + 1); } } void splay_print (splay *t, prfunc print) { _splay_print(t, print, 0); } int numcmp (void *a, void *b) { return ((int)a - (int)b); } int idcmp (void *a, void *b) { return (a - b); } int intprint (void *i) { return printf("%d\n", (int)i); }