blob: ca27cad11fbc1b8f16197b5b61e6a6f31febca13 [file] [log] [blame]
Jamie Hannafordb2b237f2014-09-15 12:17:47 +02001package testhelper
2
Jamie Hannaford6bcf2582014-09-15 12:52:51 +02003import (
Jamie Hannaford6d8dcd02014-09-25 13:55:09 +02004 "fmt"
5 "path/filepath"
Jamie Hannaford6bcf2582014-09-15 12:52:51 +02006 "reflect"
Jamie Hannaford6d8dcd02014-09-25 13:55:09 +02007 "runtime"
Ash Wilson4e034de2014-10-21 13:56:20 -04008 "strings"
Jamie Hannaford6bcf2582014-09-15 12:52:51 +02009 "testing"
10)
Jamie Hannafordb2b237f2014-09-15 12:17:47 +020011
// prefix builds the "Failure in <file>, line <n>:" header that starts every
// failure message. depth is passed straight to runtime.Caller: 0 is prefix
// itself, 1 is whoever called prefix, and so on. Only the base name of the
// source file is shown.
func prefix(depth int) string {
	_, fullPath, lineNo, _ := runtime.Caller(depth)
	fileName := filepath.Base(fullPath)
	return fmt.Sprintf("Failure in %s, line %d:", fileName, lineNo)
}
16
// green formats str in Go syntax (%#v) wrapped in bold-green ANSI escapes.
// Afterwards it switches the terminal back to bold red, so the surrounding
// red failure text keeps its colour.
func green(str interface{}) string {
	const (
		reset     = "\033[0m"
		boldGreen = "\033[1;32m"
		boldRed   = "\033[1;31m"
	)
	return reset + boldGreen + fmt.Sprintf("%#v", str) + reset + boldRed
}
20
// yellow formats str in Go syntax (%#v) wrapped in bold-yellow ANSI escapes.
// Afterwards it switches the terminal back to bold red, so the surrounding
// red failure text keeps its colour.
func yellow(str interface{}) string {
	const (
		reset      = "\033[0m"
		boldYellow = "\033[1;33m"
		boldRed    = "\033[1;31m"
	)
	return reset + boldYellow + fmt.Sprintf("%#v", str) + reset + boldRed
}
24
25func logFatal(t *testing.T, str string) {
Ash Wilson4e034de2014-10-21 13:56:20 -040026 t.Fatalf("\033[1;31m%s %s\033[0m", prefix(3), str)
Jamie Hannaford6d8dcd02014-09-25 13:55:09 +020027}
28
29func logError(t *testing.T, str string) {
Ash Wilson4e034de2014-10-21 13:56:20 -040030 t.Errorf("\033[1;31m%s %s\033[0m", prefix(3), str)
31}
32
// diffLogger is invoked once for every difference found by deepDiff.
//   - path is the sequence of steps (".Field", "[3]", "[key]") followed from
//     the root value.
//   - expected and actual are the values found there. Either may be nil when
//     one side is missing.
type diffLogger func(path []string, expected, actual interface{})

// visit identifies a pair of addressable values of one type that
// deepDiffEqual has already started comparing. The addresses are stored in
// order (a1 <= a2), so a pair is recorded the same way whichever side each
// came from. This lets the walk terminate on cyclic structures.
type visit struct {
	a1  uintptr
	a2  uintptr
	typ reflect.Type
}

// extendPath returns a new slice holding path followed by step. Each call
// allocates a fresh backing array, so sibling hops never share storage. This
// prevents a later hop from overwriting a path already passed to a
// diffLogger.
func extendPath(path []string, step string) []string {
	hop := make([]string, len(path), len(path)+1)
	copy(hop, path)
	return append(hop, step)
}

// deepDiffEqual recursively visits the structures of "expected" and "actual"
// in parallel. logDifference is invoked with each differing value
// encountered, together with the reference path that was followed to get
// there.
//
// Values at the same path are expected to share a type. deepDiff checks this
// at the root, and the Interface case re-checks it for dynamic types.
//
// If the walk panics at some level, the values at that level are compared as
// a whole with reflect.DeepEqual instead. One example is
// reflect.Value.Interface, which panics on a value reached through an
// unexported struct field.
func deepDiffEqual(expected, actual reflect.Value, visited map[visit]bool, path []string, logDifference diffLogger) {
	defer func() {
		// Fall back to the regular reflect.DeepEqual function.
		if r := recover(); r != nil {
			var e, a interface{}
			if expected.IsValid() {
				e = expected.Interface()
			}
			if actual.IsValid() {
				a = actual.Interface()
			}

			if !reflect.DeepEqual(e, a) {
				logDifference(path, e, a)
			}
		}
	}()

	// An invalid Value stands for "missing", e.g. a map key present on only
	// one side or the Elem of a nil pointer.
	if !expected.IsValid() && actual.IsValid() {
		logDifference(path, nil, actual.Interface())
		return
	}
	if expected.IsValid() && !actual.IsValid() {
		logDifference(path, expected.Interface(), nil)
		return
	}
	if !expected.IsValid() && !actual.IsValid() {
		return
	}

	// hard reports the kinds that may participate in reference cycles and
	// are therefore tracked in visited.
	hard := func(k reflect.Kind) bool {
		switch k {
		case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct:
			return true
		}
		return false
	}

	if expected.CanAddr() && actual.CanAddr() && hard(expected.Kind()) {
		addr1 := expected.UnsafeAddr()
		addr2 := actual.UnsafeAddr()

		if addr1 > addr2 {
			addr1, addr2 = addr2, addr1
		}

		if addr1 == addr2 {
			// References are identical. We can short-circuit.
			return
		}

		typ := expected.Type()
		v := visit{addr1, addr2, typ}
		if visited[v] {
			// Already visited.
			return
		}

		// Remember this visit for later.
		visited[v] = true
	}

	switch expected.Kind() {
	case reflect.Array:
		for i := 0; i < expected.Len(); i++ {
			hop := extendPath(path, fmt.Sprintf("[%d]", i))
			deepDiffEqual(expected.Index(i), actual.Index(i), visited, hop, logDifference)
		}
		return
	case reflect.Slice:
		if expected.IsNil() != actual.IsNil() {
			logDifference(path, expected.Interface(), actual.Interface())
			return
		}
		if expected.Len() != actual.Len() {
			// Report the slices as a whole. Indexing by expected's length would
			// either run past the end of actual or silently ignore actual's
			// extra elements.
			logDifference(path, expected.Interface(), actual.Interface())
			return
		}
		if expected.Pointer() == actual.Pointer() {
			// Same length and same backing array: identical contents.
			return
		}
		for i := 0; i < expected.Len(); i++ {
			hop := extendPath(path, fmt.Sprintf("[%d]", i))
			deepDiffEqual(expected.Index(i), actual.Index(i), visited, hop, logDifference)
		}
		return
	case reflect.Interface:
		if expected.IsNil() != actual.IsNil() {
			logDifference(path, expected.Interface(), actual.Interface())
			return
		}
		if expected.IsNil() {
			return
		}
		// The dynamic types may differ even though the static types match.
		// Walking them in parallel would compare unrelated kinds.
		if expected.Elem().Type() != actual.Elem().Type() {
			logDifference(path, expected.Interface(), actual.Interface())
			return
		}
		deepDiffEqual(expected.Elem(), actual.Elem(), visited, path, logDifference)
		return
	case reflect.Ptr:
		deepDiffEqual(expected.Elem(), actual.Elem(), visited, path, logDifference)
		return
	case reflect.Struct:
		typ := expected.Type()
		for i, n := 0, expected.NumField(); i < n; i++ {
			hop := extendPath(path, "."+typ.Field(i).Name)
			deepDiffEqual(expected.Field(i), actual.Field(i), visited, hop, logDifference)
		}
		return
	case reflect.Map:
		if expected.IsNil() != actual.IsNil() {
			logDifference(path, expected.Interface(), actual.Interface())
			return
		}
		if expected.Len() == actual.Len() && expected.Pointer() == actual.Pointer() {
			return
		}
		// Compare every key of expected against actual. A key missing from
		// actual gives an invalid Value, which the validity checks above
		// report as (expected value, nil) at that key's path.
		for _, k := range expected.MapKeys() {
			hop := extendPath(path, fmt.Sprintf("[%v]", k))
			deepDiffEqual(expected.MapIndex(k), actual.MapIndex(k), visited, hop, logDifference)
		}
		// Then report keys that exist only in actual as (nil, actual value).
		for _, k := range actual.MapKeys() {
			if expected.MapIndex(k).IsValid() {
				continue
			}
			hop := extendPath(path, fmt.Sprintf("[%v]", k))
			deepDiffEqual(reflect.Value{}, actual.MapIndex(k), visited, hop, logDifference)
		}
		return
	case reflect.Func:
		// Functions are only distinguished by nil-ness.
		if expected.IsNil() != actual.IsNil() {
			logDifference(path, expected.Interface(), actual.Interface())
		}
		return
	default:
		if expected.Interface() != actual.Interface() {
			logDifference(path, expected.Interface(), actual.Interface())
		}
	}
}

// deepDiff compares expected and actual structurally and calls logDifference
// for every difference it finds.
//   - Two nil values are equal.
//   - A nil and a non-nil value differ.
//   - Values of different dynamic types differ.
//
// Both of the last two cases are reported once, at the root, with an empty
// path.
func deepDiff(expected, actual interface{}, logDifference diffLogger) {
	if expected == nil && actual == nil {
		return
	}
	if expected == nil || actual == nil {
		logDifference([]string{}, expected, actual)
		return
	}

	expectedValue := reflect.ValueOf(expected)
	actualValue := reflect.ValueOf(actual)

	if expectedValue.Type() != actualValue.Type() {
		logDifference([]string{}, expected, actual)
		return
	}
	deepDiffEqual(expectedValue, actualValue, map[visit]bool{}, []string{}, logDifference)
}
203
Jamie Hannaford0f26e5c2014-09-15 15:46:58 +0200204// AssertEquals compares two arbitrary values and performs a comparison. If the
Jamie Hannaford2964aed2014-09-15 12:20:02 +0200205// comparison fails, a fatal error is raised that will fail the test
Jamie Hannaford0f26e5c2014-09-15 15:46:58 +0200206func AssertEquals(t *testing.T, expected, actual interface{}) {
Jamie Hannafordb2b237f2014-09-15 12:17:47 +0200207 if expected != actual {
Jamie Hannaford6d8dcd02014-09-25 13:55:09 +0200208 logFatal(t, fmt.Sprintf("expected %s but got %s", green(expected), yellow(actual)))
Jamie Hannafordb2b237f2014-09-15 12:17:47 +0200209 }
210}
211
Jamie Hannaford0f26e5c2014-09-15 15:46:58 +0200212// CheckEquals is similar to AssertEquals, except with a non-fatal error
213func CheckEquals(t *testing.T, expected, actual interface{}) {
214 if expected != actual {
Jamie Hannaford6d8dcd02014-09-25 13:55:09 +0200215 logError(t, fmt.Sprintf("expected %s but got %s", green(expected), yellow(actual)))
Jamie Hannaford0f26e5c2014-09-15 15:46:58 +0200216 }
217}
218
219// AssertDeepEquals - like Equals - performs a comparison - but on more complex
Jamie Hannaford6bcf2582014-09-15 12:52:51 +0200220// structures that requires deeper inspection
Jamie Hannaford6d8dcd02014-09-25 13:55:09 +0200221func AssertDeepEquals(t *testing.T, expected, actual interface{}) {
Ash Wilson4e034de2014-10-21 13:56:20 -0400222 pre := prefix(2)
223
224 differed := false
225 deepDiff(expected, actual, func(path []string, expected, actual interface{}) {
226 differed = true
227 t.Errorf("\033[1;31m%sat %s expected %s, but got %s\033[0m",
228 pre,
229 strings.Join(path, ""),
230 green(expected),
231 yellow(actual))
232 })
233 if differed {
234 logFatal(t, "The structures were different.")
Jamie Hannaford6bcf2582014-09-15 12:52:51 +0200235 }
236}
237
Jamie Hannaford0f26e5c2014-09-15 15:46:58 +0200238// CheckDeepEquals is similar to AssertDeepEquals, except with a non-fatal error
Jamie Hannaford6d8dcd02014-09-25 13:55:09 +0200239func CheckDeepEquals(t *testing.T, expected, actual interface{}) {
Ash Wilson4e034de2014-10-21 13:56:20 -0400240 pre := prefix(2)
241
242 deepDiff(expected, actual, func(path []string, expected, actual interface{}) {
243 t.Errorf("\033[1;31m%s at %s expected %s, but got %s\033[0m",
244 pre,
245 strings.Join(path, ""),
246 green(expected),
247 yellow(actual))
248 })
Jamie Hannaford0f26e5c2014-09-15 15:46:58 +0200249}
250
251// AssertNoErr is a convenience function for checking whether an error value is
252// an actual error
253func AssertNoErr(t *testing.T, e error) {
Jamie Hannafordb2b237f2014-09-15 12:17:47 +0200254 if e != nil {
Jamie Hannaford9823bb62014-09-26 17:06:36 +0200255 logFatal(t, fmt.Sprintf("unexpected error %s", yellow(e.Error())))
Jamie Hannafordb2b237f2014-09-15 12:17:47 +0200256 }
257}
Jamie Hannaford0f26e5c2014-09-15 15:46:58 +0200258
259// CheckNoErr is similar to AssertNoErr, except with a non-fatal error
260func CheckNoErr(t *testing.T, e error) {
261 if e != nil {
Jamie Hannaford9823bb62014-09-26 17:06:36 +0200262 logError(t, fmt.Sprintf("unexpected error %s", yellow(e.Error())))
Jamie Hannaford0f26e5c2014-09-15 15:46:58 +0200263 }
264}