(Assert|Check)DeepEquals now show diffs!
diff --git a/testhelper/convenience.go b/testhelper/convenience.go
index f6cb371..ca27cad 100644
--- a/testhelper/convenience.go
+++ b/testhelper/convenience.go
@@ -5,11 +5,12 @@
"path/filepath"
"reflect"
"runtime"
+ "strings"
"testing"
)
-func prefix() string {
- _, file, line, _ := runtime.Caller(3)
+// prefix builds a "Failure in <file>, line <n>:" label for a frame further up
+// the call stack. depth is passed straight to runtime.Caller, so 0 names
+// prefix itself, 1 its caller, and so on; only the base name of the file is
+// kept. The ok result of runtime.Caller is discarded, so a frame that cannot be
+// resolved still yields a label instead of an error.
+func prefix(depth int) string {
+ _, file, line, _ := runtime.Caller(depth)
return fmt.Sprintf("Failure in %s, line %d:", filepath.Base(file), line)
}
@@ -22,11 +23,182 @@
}
+// logFatal fails the test immediately via t.Fatalf. It prints str in bold red
+// (ANSI "\033[1;31m" ... "\033[0m") after a location label from prefix(3). That
+// label is the frame that called logFatal's caller. When an assertion helper
+// calls logFatal directly, this is the helper's caller.
func logFatal(t *testing.T, str string) {
- t.Fatalf("\033[1;31m%s %s\033[0m", prefix(), str)
+ t.Fatalf("\033[1;31m%s %s\033[0m", prefix(3), str)
}
+// logError is the non-fatal counterpart of logFatal. It records a failure via
+// t.Errorf and lets the test keep running. The message is str in bold red after
+// a location label from prefix(3), which is the frame that called logError's
+// caller.
func logError(t *testing.T, str string) {
- t.Errorf("\033[1;31m%s %s\033[0m", prefix(), str)
+ t.Errorf("\033[1;31m%s %s\033[0m", prefix(3), str)
+}
+
+// diffLogger receives one reported difference. path holds the access segments
+// that lead to it (".Field", "[3]", "[key]"), which are meant to be joined with
+// "". expected and actual are the values found there. A side that has no value
+// at that path (e.g. a missing map key) is passed as nil.
+type diffLogger func(path []string, expected, actual interface{})
+
+// visit identifies a pair of addressable composite values of one type that
+// have already been compared, so cyclic structures are not walked forever.
+type visit struct {
+	a1  uintptr
+	a2  uintptr
+	typ reflect.Type
+}
+
+// deepDiffEqual walks expected and actual in parallel. It calls logDifference
+// for each differing value it finds, together with the reference path that
+// was followed to reach it. It is intended to agree with reflect.DeepEqual
+// about whether the inputs differ: slices and maps must match in nil-ness,
+// length and keys, and non-nil funcs are never equal.
+//
+// visited records addressable composite values already compared, which stops
+// infinite recursion on cyclic data. If the walk of a node panics, that node
+// falls back to reflect.DeepEqual and is logged as a whole if it is unequal.
+func deepDiffEqual(expected, actual reflect.Value, visited map[visit]bool, path []string, logDifference diffLogger) {
+	defer func() {
+		r := recover()
+		if r == nil {
+			return
+		}
+		// Interface() itself panics on values reached through unexported
+		// fields. Re-panic so the nearest ancestor whose values can be
+		// converted performs the fallback instead.
+		if (expected.IsValid() && !expected.CanInterface()) || (actual.IsValid() && !actual.CanInterface()) {
+			panic(r)
+		}
+		// Fall back to the regular reflect.DeepEqual function.
+		var e, a interface{}
+		if expected.IsValid() {
+			e = expected.Interface()
+		}
+		if actual.IsValid() {
+			a = actual.Interface()
+		}
+		if !reflect.DeepEqual(e, a) {
+			logDifference(path, e, a)
+		}
+	}()
+
+	switch {
+	case !expected.IsValid() && !actual.IsValid():
+		return
+	case !expected.IsValid():
+		logDifference(path, nil, actual.Interface())
+		return
+	case !actual.IsValid():
+		logDifference(path, expected.Interface(), nil)
+		return
+	case expected.Type() != actual.Type():
+		// Reachable below interface values holding different dynamic types.
+		logDifference(path, expected.Interface(), actual.Interface())
+		return
+	}
+
+	hard := func(k reflect.Kind) bool {
+		switch k {
+		case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct:
+			return true
+		}
+		return false
+	}
+
+	if expected.CanAddr() && actual.CanAddr() && hard(expected.Kind()) {
+		addr1 := expected.UnsafeAddr()
+		addr2 := actual.UnsafeAddr()
+		if addr1 > addr2 {
+			// Canonicalize the order so (a, b) and (b, a) share one entry.
+			addr1, addr2 = addr2, addr1
+		}
+		if addr1 == addr2 {
+			// Both sides are the same memory, so they are trivially equal.
+			return
+		}
+		v := visit{addr1, addr2, expected.Type()}
+		if visited[v] {
+			// Already compared (or being compared further up a cycle).
+			return
+		}
+		visited[v] = true
+	}
+
+	switch expected.Kind() {
+	case reflect.Array:
+		// Equal types imply equal array lengths.
+		for i := 0; i < expected.Len(); i++ {
+			hop := extendDiffPath(path, fmt.Sprintf("[%d]", i))
+			deepDiffEqual(expected.Index(i), actual.Index(i), visited, hop, logDifference)
+		}
+	case reflect.Slice:
+		if expected.IsNil() != actual.IsNil() || expected.Len() != actual.Len() {
+			// Report the whole slice: an element-wise walk would either panic
+			// or silently ignore the extra elements of the longer side.
+			logDifference(path, expected.Interface(), actual.Interface())
+			return
+		}
+		if expected.Pointer() == actual.Pointer() {
+			// Same backing array and same length: identical contents.
+			return
+		}
+		for i := 0; i < expected.Len(); i++ {
+			hop := extendDiffPath(path, fmt.Sprintf("[%d]", i))
+			deepDiffEqual(expected.Index(i), actual.Index(i), visited, hop, logDifference)
+		}
+	case reflect.Interface:
+		if expected.IsNil() != actual.IsNil() {
+			logDifference(path, expected.Interface(), actual.Interface())
+			return
+		}
+		deepDiffEqual(expected.Elem(), actual.Elem(), visited, path, logDifference)
+	case reflect.Ptr:
+		// A nil pointer yields an invalid Elem, which the checks above report.
+		deepDiffEqual(expected.Elem(), actual.Elem(), visited, path, logDifference)
+	case reflect.Struct:
+		for i, n := 0, expected.NumField(); i < n; i++ {
+			hop := extendDiffPath(path, "."+expected.Type().Field(i).Name)
+			deepDiffEqual(expected.Field(i), actual.Field(i), visited, hop, logDifference)
+		}
+	case reflect.Map:
+		if expected.IsNil() != actual.IsNil() {
+			logDifference(path, expected.Interface(), actual.Interface())
+			return
+		}
+		if expected.Len() == actual.Len() && expected.Pointer() == actual.Pointer() {
+			return
+		}
+		// Compare every key of expected. A key missing from actual yields an
+		// invalid Value, which the recursive call reports as (value, nil).
+		for _, k := range expected.MapKeys() {
+			hop := extendDiffPath(path, fmt.Sprintf("[%v]", k))
+			deepDiffEqual(expected.MapIndex(k), actual.MapIndex(k), visited, hop, logDifference)
+		}
+		// Then report the keys only actual has, as (nil, value).
+		for _, k := range actual.MapKeys() {
+			if !expected.MapIndex(k).IsValid() {
+				hop := extendDiffPath(path, fmt.Sprintf("[%v]", k))
+				deepDiffEqual(reflect.Value{}, actual.MapIndex(k), visited, hop, logDifference)
+			}
+		}
+	case reflect.Func:
+		// As in reflect.DeepEqual, func values are equal only if both are nil.
+		if !expected.IsNil() || !actual.IsNil() {
+			logDifference(path, expected.Interface(), actual.Interface())
+		}
+	default:
+		if expected.Interface() != actual.Interface() {
+			logDifference(path, expected.Interface(), actual.Interface())
+		}
+	}
+}
+
+// extendDiffPath returns a new slice holding path followed by hop. A fresh
+// backing array per hop keeps sibling recursive calls from overwriting each
+// other's segments if a diffLogger retains a path it was given.
+func extendDiffPath(path []string, hop string) []string {
+	next := make([]string, len(path)+1)
+	copy(next, path)
+	next[len(path)] = hop
+	return next
+}
+
+// deepDiff compares expected and actual and calls logDifference once for each
+// difference it finds. A single nil input, or inputs of different dynamic
+// types, is reported as one difference at the root (an empty path). Otherwise
+// the values are walked recursively by deepDiffEqual.
+func deepDiff(expected, actual interface{}, logDifference diffLogger) {
+	if expected == nil && actual == nil {
+		// reflect.DeepEqual(nil, nil) is true, so there is nothing to report.
+		return
+	}
+	if expected == nil || actual == nil {
+		logDifference([]string{}, expected, actual)
+		return
+	}
+
+	expectedValue := reflect.ValueOf(expected)
+	actualValue := reflect.ValueOf(actual)
+
+	if expectedValue.Type() != actualValue.Type() {
+		logDifference([]string{}, expected, actual)
+		return
+	}
+	deepDiffEqual(expectedValue, actualValue, map[visit]bool{}, []string{}, logDifference)
}
// AssertEquals compares two arbitrary values and performs a comparison. If the
@@ -47,16 +219,33 @@
// AssertDeepEquals - like Equals - performs a comparison - but on more complex
// structures that requires deeper inspection
func AssertDeepEquals(t *testing.T, expected, actual interface{}) {
- if !reflect.DeepEqual(expected, actual) {
- logFatal(t, fmt.Sprintf("expected %s but got %s", green(expected), yellow(actual)))
+	// Resolve the caller's file/line here. The reporter below is invoked from
+	// deep inside deepDiff, where a fixed stack depth would name the wrong frame.
+	pre := prefix(2)
+
+	// Report every differing value as a non-fatal error first, so the whole
+	// structure is described, then stop the test once the walk is complete.
+	differed := false
+	deepDiff(expected, actual, func(path []string, want, got interface{}) {
+		differed = true
+		t.Errorf("\033[1;31m%s at %s expected %s, but got %s\033[0m",
+			pre,
+			strings.Join(path, ""),
+			green(want),
+			yellow(got))
+	})
+	if differed {
+		logFatal(t, "The structures were different.")
}
}
// CheckDeepEquals is similar to AssertDeepEquals, except with a non-fatal error
func CheckDeepEquals(t *testing.T, expected, actual interface{}) {
- if !reflect.DeepEqual(expected, actual) {
- logError(t, fmt.Sprintf("expected %s but got %s", green(expected), yellow(actual)))
- }
+	// Capture the caller's location up front. The reporter closure is invoked
+	// from inside deepDiff, several frames below the test function.
+	callSite := prefix(2)
+
+	// Each difference becomes its own non-fatal error, so the test keeps running.
+	report := func(path []string, want, got interface{}) {
+		location := strings.Join(path, "")
+		t.Errorf("\033[1;31m%s at %s expected %s, but got %s\033[0m", callSite, location, green(want), yellow(got))
+	}
+	deepDiff(expected, actual, report)
}
// AssertNoErr is a convenience function for checking whether an error value is