23 April 2014

What to do when you get "Comparison method violates its general contract!"

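Java 7 switched the sort algorithm for objects to TimSort (implemented in java.util.TimSort, and in java.util.ComparableTimSort when sorting by natural ordering). The old merge sort silently tolerated a Comparator/Comparable that violated its contract; the new algorithm can detect such violations and throws an IllegalArgumentException: "Comparison method violates its general contract!"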

The contract is specified in the Javadoc of the Comparable interface (and likewise of Comparator). The most common mistake is breaking transitivity: if A < B and B < C, then A must be less than C. Stack Overflow has several posts on the issue.

I have written a small program that detects which elements cause this issue. It only works for small lists, but it has helped me debug compareTo()/compare() methods.

package se.lesc.blog;

import java.util.Comparator;

public class TransitivityTester {

    /**
     * Checks that the natural ordering ({@link Comparable#compareTo}) of the given elements obeys
     * the general contract; equivalent to {@code transitivityTester(list, null)}.
     *
     * @param list the elements to check; each must implement {@link Comparable}
     * @throws IllegalArgumentException on the first contract violation found
     */
    public static void transitivityTester(Object[] list) {
        transitivityTester(list, null);
    }

    /**
     * Checks that a comparison method obeys the general contract of
     * {@link Comparator#compare}/{@link Comparable#compareTo} on every pair and triple of
     * elements in {@code list}:
     * <ul>
     *   <li>antisymmetry: {@code sgn(compare(x, y)) == -sgn(compare(y, x))}, which also requires
     *       {@code compare(x, x) == 0};</li>
     *   <li>transitivity of {@code <}, {@code >} and {@code ==}.</li>
     * </ul>
     * Together these also cover the rule "{@code compare(x, y) == 0} implies
     * {@code sgn(compare(x, z)) == sgn(compare(y, z))}".
     *
     * <p>Each pair is compared once and cached; the triple check is O(n^3), so use small inputs.
     *
     * @param <T> the element type
     * @param list the elements to check
     * @param comparator the comparator to check, or {@code null} to check the elements' natural
     *     ordering (each element must then implement {@link Comparable})
     * @throws IllegalArgumentException on the first violation found; the message names the broken
     *     rule and the offending indices and elements
     */
    public static <T> void transitivityTester(T[] list, Comparator<? super T> comparator) {
        byte[][] signs = compareAllPairs(list, comparator);
        int n = list.length;

        // Antisymmetry. Without this check a comparator with compare(a, b) == 0 but
        // compare(b, a) > 0 can pass every transitivity test below (e.g. on a two-element list).
        for (int a = 0; a < n; a++) {
            for (int b = a; b < n; b++) {
                if (signs[a][b] != -signs[b][a]) {
                    throw contractViolation(
                            "sgn(compare(A, B)) != -sgn(compare(B, A))", list, a, b);
                }
            }
        }

        // Expensive O(n^3) iteration. Signs are normalized to -1/0/1, so "A op B and B op C
        // imply A op C" for each of <, == and > collapses into a single equality test.
        for (int a = 0; a < n; a++) {
            for (int b = 0; b < n; b++) {
                byte ab = signs[a][b];
                for (int c = 0; c < n; c++) {
                    if (signs[b][c] == ab && signs[a][c] != ab) {
                        throw contractViolation(transitivityRule(ab), list, a, b, c);
                    }
                }
            }
        }
    }

    /** Caches the normalized sign of compare(list[x], list[y]) for every ordered pair (x, y). */
    private static <T> byte[][] compareAllPairs(T[] list, Comparator<? super T> comparator) {
        byte[][] signs = new byte[list.length][list.length];
        for (int x = 0; x < list.length; x++) {
            for (int y = 0; y < list.length; y++) {
                signs[x][y] = normalize(compare(list[x], list[y], comparator));
            }
        }
        return signs;
    }

    /** Compares with {@code comparator}, falling back to natural ordering when it is null. */
    private static <T> int compare(T x, T y, Comparator<? super T> comparator) {
        if (comparator != null) {
            return comparator.compare(x, y);
        }
        // Natural ordering: a ClassCastException here means the element is not Comparable.
        @SuppressWarnings("unchecked")
        Comparable<Object> comparableX = (Comparable<Object>) x;
        return comparableX.compareTo(y);
    }

    /** Describes the transitivity rule that failed for the given shared sign of A?B and B?C. */
    private static String transitivityRule(byte sign) {
        if (sign < 0) {
            return "A < B && B < C but not A < C";
        }
        if (sign > 0) {
            return "A > B && B > C but not A > C";
        }
        return "A == B && B == C but not A == C";
    }

    /**
     * Builds the exception for a violated rule, labelling the given indices A, B, C, ... and
     * showing the element stored at each index.
     */
    private static IllegalArgumentException contractViolation(
            String rule, Object[] list, int... indices) {
        StringBuilder message = new StringBuilder(rule).append(" (");
        for (int i = 0; i < indices.length; i++) {
            if (i > 0) {
                message.append(", ");
            }
            message.append((char) ('A' + i)).append(" = ").append(indices[i])
                    .append(" [").append(list[indices[i]]).append(']');
        }
        message.append(") Comparison method violates its general contract!");
        return new IllegalArgumentException(message.toString());
    }

    /** Maps any compare() result onto -1, 0 or 1. */
    private static byte normalize(int result) {
        return (byte) Integer.signum(result);
    }
}

Here is some test code for the above class:
package se.lesc.blog;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

import org.junit.Before;
import org.junit.Test;

public class TransitivityTesterTest {

    private List<Integer> list;

    @Before
    public void setUp() {
        list = new ArrayList<Integer>();
        list.add(1);
        list.add(2);
        list.add(3);
    }

    @Test
    public void testNormalIntegerListShouldWork() {
        TransitivityTester.transitivityTester(list.toArray(new Integer[0]));        
    }
    
    @Test
    public void testCrazyOneIsAlwaysMoreComparator() {
        TransitivityTester.transitivityTester(list.toArray(new Integer[0]), new CrazyOneIsAlwaysMoreComparator());        
    }
    
    @Test
    public void testCrazyAllIsOneComparator() {
        TransitivityTester.transitivityTester(list.toArray(new Integer[0]), new CrazyAllIsOneComparator());        
    }
    
    public static class CrazyOneIsAlwaysMoreComparator implements Comparator<Object> {
        @Override
        public int compare(Object o1, Object o2) {
            Integer i1 = (Integer) o1;
            Integer i2 = (Integer) o2;
            if (i1.intValue() == 1) {
                return 1; 
            } else {
                return i1.compareTo(i2);
            }
        }
    }
    
    public static class CrazyAllIsOneComparator implements Comparator<Object> {
        @Override
        public int compare(Object o1, Object o2) {
            Integer i1 = (Integer) o1;
            Integer i2 = (Integer) o2;
            if (i1.intValue() == 1 || i2.intValue() == 1) {
                return 0; 
            } else {
                return i1.compareTo(i2);
            }
        }
    }
}