blob: f21c3f95d1c2a0a8aaa175417ea6c044c57a052b [file] [log] [blame]
Jamie Hannafordb2b237f2014-09-15 12:17:47 +02001package testhelper
2
Jamie Hannaford6bcf2582014-09-15 12:52:51 +02003import (
jrperrittc5c590a2016-11-04 14:41:15 -05004 "bytes"
Ash Wilson3315cf92014-10-23 10:27:35 -04005 "encoding/json"
Jamie Hannaford6d8dcd02014-09-25 13:55:09 +02006 "fmt"
7 "path/filepath"
Jamie Hannaford6bcf2582014-09-15 12:52:51 +02008 "reflect"
Jamie Hannaford6d8dcd02014-09-25 13:55:09 +02009 "runtime"
Ash Wilson4e034de2014-10-21 13:56:20 -040010 "strings"
Jamie Hannaford6bcf2582014-09-15 12:52:51 +020011 "testing"
12)
Jamie Hannafordb2b237f2014-09-15 12:17:47 +020013
// logBodyFmt prints a location prefix and a message in bold red, then resets all attributes.
const logBodyFmt = "\033[1;31m%s %s\033[0m"

// Colour switches used inside a (bold red) failure message. greenCode and yellowCode highlight the
// expected and actual values; resetCode switches back to bold red rather than to the terminal
// default, so the remainder of the message keeps the failure colour.
const (
	greenCode  = "\033[0m\033[1;32m"
	yellowCode = "\033[0m\033[1;33m"
	resetCode  = "\033[0m\033[1;31m"
)
20
// prefix describes the source location `depth` frames above prefix itself (depth 0 is prefix,
// 1 is its caller, and so on) as "Failure in <file>, line <n>:". Only the base name of the file
// is shown.
func prefix(depth int) string {
	var (
		file string
		line int
	)
	_, file, line, _ = runtime.Caller(depth)
	base := filepath.Base(file)
	return fmt.Sprintf("Failure in %s, line %d:", base, line)
}
25
26func green(str interface{}) string {
Ash Wilson3315cf92014-10-23 10:27:35 -040027 return fmt.Sprintf("%s%#v%s", greenCode, str, resetCode)
Jamie Hannaford6d8dcd02014-09-25 13:55:09 +020028}
29
30func yellow(str interface{}) string {
Ash Wilson3315cf92014-10-23 10:27:35 -040031 return fmt.Sprintf("%s%#v%s", yellowCode, str, resetCode)
Jamie Hannaford6d8dcd02014-09-25 13:55:09 +020032}
33
34func logFatal(t *testing.T, str string) {
Ash Wilson3315cf92014-10-23 10:27:35 -040035 t.Fatalf(logBodyFmt, prefix(3), str)
Jamie Hannaford6d8dcd02014-09-25 13:55:09 +020036}
37
38func logError(t *testing.T, str string) {
Ash Wilson3315cf92014-10-23 10:27:35 -040039 t.Errorf(logBodyFmt, prefix(3), str)
Ash Wilson4e034de2014-10-21 13:56:20 -040040}
41
// diffLogger is invoked once per difference found by deepDiff / deepDiffEqual.
//
// Arguments:
//   - path: the sequence of field selectors (".Name"), indices ("[0]") and map keys ("[key]")
//     followed from the root values.
//   - expected, actual: the differing values at that position. One side is nil when a slice
//     element or map entry exists on only one side.
type diffLogger func(path []string, expected, actual interface{})

// visit identifies a pair of addressable composite values of one type that deepDiffEqual has
// already compared. a1 always holds the lower address, so the pair is order-independent.
// Remembering visits stops the walk from looping forever on cyclic structures.
type visit struct {
	a1  uintptr
	a2  uintptr
	typ reflect.Type
}

// appendPath returns path extended by segment in a freshly allocated slice. A plain append could
// let sibling paths built from the same prefix share (and overwrite) one backing array.
func appendPath(path []string, segment string) []string {
	return append(path[:len(path):len(path)], segment)
}

// deepDiffEqual recursively visits the structures of expected and actual, which are normally of
// the same type. For each differing value it invokes logDifference together with the path that
// was followed to reach it.
//
//   - Arrays and slices are compared element by element. Slice elements present on only one side
//     are reported individually against nil.
//   - Maps are compared over the union of both key sets. A key present on only one side is
//     reported against nil.
//   - Structs are compared field by field. Pointers and interfaces are followed to the values
//     they refer to, and interfaces holding different dynamic types are reported as a whole.
//   - Funcs are only checked for nil-ness. Every other kind is compared with ==.
//
// If reflection panics part-way through (for example when calling Interface on a value reached
// through an unexported field), the deferred handler falls back to reflect.DeepEqual for the
// current pair of values.
func deepDiffEqual(expected, actual reflect.Value, visited map[visit]bool, path []string, logDifference diffLogger) {
	defer func() {
		// Fall back to the regular reflect.DeepEqual function.
		if r := recover(); r != nil {
			var e, a interface{}
			if expected.IsValid() {
				e = expected.Interface()
			}
			if actual.IsValid() {
				a = actual.Interface()
			}

			if !reflect.DeepEqual(e, a) {
				logDifference(path, e, a)
			}
		}
	}()

	if !expected.IsValid() && actual.IsValid() {
		logDifference(path, nil, actual.Interface())
		return
	}
	if expected.IsValid() && !actual.IsValid() {
		logDifference(path, expected.Interface(), nil)
		return
	}
	if !expected.IsValid() && !actual.IsValid() {
		return
	}

	// hard reports whether values of kind k are remembered in visited.
	hard := func(k reflect.Kind) bool {
		switch k {
		case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct:
			return true
		}
		return false
	}

	if expected.CanAddr() && actual.CanAddr() && hard(expected.Kind()) {
		addr1 := expected.UnsafeAddr()
		addr2 := actual.UnsafeAddr()

		// Order the addresses so a pair is recorded identically whichever side each came from.
		if addr1 > addr2 {
			addr1, addr2 = addr2, addr1
		}

		if addr1 == addr2 {
			// References are identical. We can short-circuit.
			return
		}

		typ := expected.Type()
		v := visit{addr1, addr2, typ}
		if visited[v] {
			// Already visited.
			return
		}

		// Remember this visit for later.
		visited[v] = true
	}

	switch expected.Kind() {
	case reflect.Array:
		for i := 0; i < expected.Len(); i++ {
			hop := appendPath(path, fmt.Sprintf("[%d]", i))
			deepDiffEqual(expected.Index(i), actual.Index(i), visited, hop, logDifference)
		}
		return
	case reflect.Slice:
		if expected.IsNil() != actual.IsNil() {
			logDifference(path, expected.Interface(), actual.Interface())
			return
		}
		if expected.Len() == actual.Len() && expected.Pointer() == actual.Pointer() {
			return
		}
		// Walk up to the longer length. Shared indices are compared recursively; elements present
		// on only one side are reported against nil, so a length mismatch is never ignored.
		expectedLen, actualLen := expected.Len(), actual.Len()
		n := expectedLen
		if actualLen > n {
			n = actualLen
		}
		for i := 0; i < n; i++ {
			hop := appendPath(path, fmt.Sprintf("[%d]", i))
			switch {
			case i >= actualLen:
				logDifference(hop, expected.Index(i).Interface(), nil)
			case i >= expectedLen:
				logDifference(hop, nil, actual.Index(i).Interface())
			default:
				deepDiffEqual(expected.Index(i), actual.Index(i), visited, hop, logDifference)
			}
		}
		return
	case reflect.Interface:
		if expected.IsNil() != actual.IsNil() {
			logDifference(path, expected.Interface(), actual.Interface())
			return
		}
		if expected.IsNil() {
			// Both interfaces are nil.
			return
		}
		// Different dynamic types are never equal. Recursing anyway could, for example, report two
		// distinct struct types with identical field values as equal.
		if expected.Elem().Type() != actual.Elem().Type() {
			logDifference(path, expected.Interface(), actual.Interface())
			return
		}
		deepDiffEqual(expected.Elem(), actual.Elem(), visited, path, logDifference)
		return
	case reflect.Ptr:
		deepDiffEqual(expected.Elem(), actual.Elem(), visited, path, logDifference)
		return
	case reflect.Struct:
		for i, n := 0, expected.NumField(); i < n; i++ {
			field := expected.Type().Field(i)
			hop := appendPath(path, "."+field.Name)
			deepDiffEqual(expected.Field(i), actual.Field(i), visited, hop, logDifference)
		}
		return
	case reflect.Map:
		if expected.IsNil() != actual.IsNil() {
			logDifference(path, expected.Interface(), actual.Interface())
			return
		}
		if expected.Len() == actual.Len() && expected.Pointer() == actual.Pointer() {
			return
		}

		// Build the union of both key sets: expected's keys first, then any keys only actual has.
		keys := expected.MapKeys()
		for _, k := range actual.MapKeys() {
			if !expected.MapIndex(k).IsValid() {
				keys = append(keys, k)
			}
		}

		for _, k := range keys {
			hop := appendPath(path, fmt.Sprintf("[%v]", k))
			expectedValue := expected.MapIndex(k)
			actualValue := actual.MapIndex(k)

			switch {
			case !expectedValue.IsValid():
				logDifference(hop, nil, actualValue.Interface())
			case !actualValue.IsValid():
				logDifference(hop, expectedValue.Interface(), nil)
			default:
				deepDiffEqual(expectedValue, actualValue, visited, hop, logDifference)
			}
		}
		return
	case reflect.Func:
		if expected.IsNil() != actual.IsNil() {
			logDifference(path, expected.Interface(), actual.Interface())
		}
		return
	default:
		if expected.Interface() != actual.Interface() {
			logDifference(path, expected.Interface(), actual.Interface())
		}
	}
}

// deepDiff compares expected and actual and invokes logDifference for every difference found.
// Two nil values are equal. If exactly one value is nil, or their dynamic types differ, the
// values are reported once, as a whole, with an empty path.
func deepDiff(expected, actual interface{}, logDifference diffLogger) {
	if expected == nil && actual == nil {
		return
	}
	if expected == nil || actual == nil {
		logDifference([]string{}, expected, actual)
		return
	}

	expectedValue := reflect.ValueOf(expected)
	actualValue := reflect.ValueOf(actual)

	if expectedValue.Type() != actualValue.Type() {
		logDifference([]string{}, expected, actual)
		return
	}
	deepDiffEqual(expectedValue, actualValue, map[visit]bool{}, []string{}, logDifference)
}
212
Jamie Hannaford0f26e5c2014-09-15 15:46:58 +0200213// AssertEquals compares two arbitrary values and performs a comparison. If the
Jamie Hannaford2964aed2014-09-15 12:20:02 +0200214// comparison fails, a fatal error is raised that will fail the test
Jamie Hannaford0f26e5c2014-09-15 15:46:58 +0200215func AssertEquals(t *testing.T, expected, actual interface{}) {
Jamie Hannafordb2b237f2014-09-15 12:17:47 +0200216 if expected != actual {
Jamie Hannaford6d8dcd02014-09-25 13:55:09 +0200217 logFatal(t, fmt.Sprintf("expected %s but got %s", green(expected), yellow(actual)))
Jamie Hannafordb2b237f2014-09-15 12:17:47 +0200218 }
219}
220
Jamie Hannaford0f26e5c2014-09-15 15:46:58 +0200221// CheckEquals is similar to AssertEquals, except with a non-fatal error
222func CheckEquals(t *testing.T, expected, actual interface{}) {
223 if expected != actual {
Jamie Hannaford6d8dcd02014-09-25 13:55:09 +0200224 logError(t, fmt.Sprintf("expected %s but got %s", green(expected), yellow(actual)))
Jamie Hannaford0f26e5c2014-09-15 15:46:58 +0200225 }
226}
227
228// AssertDeepEquals - like Equals - performs a comparison - but on more complex
Jamie Hannaford6bcf2582014-09-15 12:52:51 +0200229// structures that requires deeper inspection
Jamie Hannaford6d8dcd02014-09-25 13:55:09 +0200230func AssertDeepEquals(t *testing.T, expected, actual interface{}) {
Ash Wilson4e034de2014-10-21 13:56:20 -0400231 pre := prefix(2)
232
233 differed := false
234 deepDiff(expected, actual, func(path []string, expected, actual interface{}) {
235 differed = true
236 t.Errorf("\033[1;31m%sat %s expected %s, but got %s\033[0m",
237 pre,
238 strings.Join(path, ""),
239 green(expected),
240 yellow(actual))
241 })
242 if differed {
243 logFatal(t, "The structures were different.")
Jamie Hannaford6bcf2582014-09-15 12:52:51 +0200244 }
245}
246
Jamie Hannaford0f26e5c2014-09-15 15:46:58 +0200247// CheckDeepEquals is similar to AssertDeepEquals, except with a non-fatal error
Jamie Hannaford6d8dcd02014-09-25 13:55:09 +0200248func CheckDeepEquals(t *testing.T, expected, actual interface{}) {
Ash Wilson4e034de2014-10-21 13:56:20 -0400249 pre := prefix(2)
250
251 deepDiff(expected, actual, func(path []string, expected, actual interface{}) {
252 t.Errorf("\033[1;31m%s at %s expected %s, but got %s\033[0m",
253 pre,
254 strings.Join(path, ""),
255 green(expected),
256 yellow(actual))
257 })
Jamie Hannaford0f26e5c2014-09-15 15:46:58 +0200258}
259
// isByteArrayEquals reports whether expectedBytes and actualBytes contain the same bytes, treating
// a nil slice as equal to an empty one. The t parameter is not used.
func isByteArrayEquals(t *testing.T, expectedBytes []byte, actualBytes []byte) bool {
	if !bytes.Equal(expectedBytes, actualBytes) {
		return false
	}
	return true
}
263
264// AssertByteArrayEquals a convenience function for checking whether two byte arrays are equal
265func AssertByteArrayEquals(t *testing.T, expectedBytes []byte, actualBytes []byte) {
266 if !isByteArrayEquals(t, expectedBytes, actualBytes) {
267 logFatal(t, "The bytes differed.")
268 }
269}
270
271// CheckByteArrayEquals a convenience function for silent checking whether two byte arrays are equal
272func CheckByteArrayEquals(t *testing.T, expectedBytes []byte, actualBytes []byte) {
273 if !isByteArrayEquals(t, expectedBytes, actualBytes) {
274 logError(t, "The bytes differed.")
275 }
276}
277
Ash Wilson3315cf92014-10-23 10:27:35 -0400278// isJSONEquals is a utility function that implements JSON comparison for AssertJSONEquals and
279// CheckJSONEquals.
280func isJSONEquals(t *testing.T, expectedJSON string, actual interface{}) bool {
Jon Perritt485b8aa2014-10-24 12:51:16 -0500281 var parsedExpected, parsedActual interface{}
Ash Wilson3315cf92014-10-23 10:27:35 -0400282 err := json.Unmarshal([]byte(expectedJSON), &parsedExpected)
283 if err != nil {
284 t.Errorf("Unable to parse expected value as JSON: %v", err)
285 return false
286 }
287
Jon Perritt485b8aa2014-10-24 12:51:16 -0500288 jsonActual, err := json.Marshal(actual)
289 AssertNoErr(t, err)
290 err = json.Unmarshal(jsonActual, &parsedActual)
291 AssertNoErr(t, err)
292
293 if !reflect.DeepEqual(parsedExpected, parsedActual) {
Ash Wilson3315cf92014-10-23 10:27:35 -0400294 prettyExpected, err := json.MarshalIndent(parsedExpected, "", " ")
295 if err != nil {
296 t.Logf("Unable to pretty-print expected JSON: %v\n%s", err, expectedJSON)
297 } else {
298 // We can't use green() here because %#v prints prettyExpected as a byte array literal, which
299 // is... unhelpful. Converting it to a string first leaves "\n" uninterpreted for some reason.
300 t.Logf("Expected JSON:\n%s%s%s", greenCode, prettyExpected, resetCode)
301 }
302
303 prettyActual, err := json.MarshalIndent(actual, "", " ")
304 if err != nil {
305 t.Logf("Unable to pretty-print actual JSON: %v\n%#v", err, actual)
306 } else {
307 // We can't use yellow() for the same reason.
308 t.Logf("Actual JSON:\n%s%s%s", yellowCode, prettyActual, resetCode)
309 }
310
311 return false
312 }
313 return true
314}
315
316// AssertJSONEquals serializes a value as JSON, parses an expected string as JSON, and ensures that
317// both are consistent. If they aren't, the expected and actual structures are pretty-printed and
318// shown for comparison.
319//
320// This is useful for comparing structures that are built as nested map[string]interface{} values,
321// which are a pain to construct as literals.
322func AssertJSONEquals(t *testing.T, expectedJSON string, actual interface{}) {
323 if !isJSONEquals(t, expectedJSON, actual) {
324 logFatal(t, "The generated JSON structure differed.")
325 }
326}
327
328// CheckJSONEquals is similar to AssertJSONEquals, but nonfatal.
329func CheckJSONEquals(t *testing.T, expectedJSON string, actual interface{}) {
330 if !isJSONEquals(t, expectedJSON, actual) {
331 logError(t, "The generated JSON structure differed.")
332 }
333}
334
Jamie Hannaford0f26e5c2014-09-15 15:46:58 +0200335// AssertNoErr is a convenience function for checking whether an error value is
336// an actual error
337func AssertNoErr(t *testing.T, e error) {
Jamie Hannafordb2b237f2014-09-15 12:17:47 +0200338 if e != nil {
Jamie Hannaford9823bb62014-09-26 17:06:36 +0200339 logFatal(t, fmt.Sprintf("unexpected error %s", yellow(e.Error())))
Jamie Hannafordb2b237f2014-09-15 12:17:47 +0200340 }
341}
Jamie Hannaford0f26e5c2014-09-15 15:46:58 +0200342
343// CheckNoErr is similar to AssertNoErr, except with a non-fatal error
344func CheckNoErr(t *testing.T, e error) {
345 if e != nil {
Jamie Hannaford9823bb62014-09-26 17:06:36 +0200346 logError(t, fmt.Sprintf("unexpected error %s", yellow(e.Error())))
Jamie Hannaford0f26e5c2014-09-15 15:46:58 +0200347 }
348}