blob: adb77e5aac9bf51698c0108d843da2fe89148032 [file] [log] [blame]
Jamie Hannafordb2b237f2014-09-15 12:17:47 +02001package testhelper
2
Jamie Hannaford6bcf2582014-09-15 12:52:51 +02003import (
Ash Wilson3315cf92014-10-23 10:27:35 -04004 "encoding/json"
Jamie Hannaford6d8dcd02014-09-25 13:55:09 +02005 "fmt"
6 "path/filepath"
Jamie Hannaford6bcf2582014-09-15 12:52:51 +02007 "reflect"
Jamie Hannaford6d8dcd02014-09-25 13:55:09 +02008 "runtime"
Ash Wilson4e034de2014-10-21 13:56:20 -04009 "strings"
Jamie Hannaford6bcf2582014-09-15 12:52:51 +020010 "testing"
11)
Jamie Hannafordb2b237f2014-09-15 12:17:47 +020012
// Terminal escape sequences used to color assertion failures. Failure text is
// printed in bold red; highlighted values switch to bold green (used for
// expected values) or bold yellow (used for actual values and errors) and then
// return to bold red.
const (
	// logBodyFmt prints "<location> <message>" in bold red, then fully resets.
	logBodyFmt = "\033[1;31m%s %s\033[0m"

	// greenCode resets attributes, then selects bold green.
	greenCode = "\033[0m\033[1;32m"

	// yellowCode resets attributes, then selects bold yellow.
	yellowCode = "\033[0m\033[1;33m"

	// resetCode does not restore the terminal default: it resets and then
	// re-selects bold red so the surrounding failure message keeps its color.
	resetCode = "\033[0m\033[1;31m"
)
19
// prefix builds the "Failure in <file>, line <n>:" header for a failure
// message. depth is handed straight to runtime.Caller: 0 is prefix itself, 1
// its caller, and so on. Only the base name of the source file is shown.
func prefix(depth int) string {
	_, sourcePath, lineNumber, _ := runtime.Caller(depth)
	fileName := filepath.Base(sourcePath)
	return fmt.Sprintf("Failure in %s, line %d:", fileName, lineNumber)
}
24
25func green(str interface{}) string {
Ash Wilson3315cf92014-10-23 10:27:35 -040026 return fmt.Sprintf("%s%#v%s", greenCode, str, resetCode)
Jamie Hannaford6d8dcd02014-09-25 13:55:09 +020027}
28
29func yellow(str interface{}) string {
Ash Wilson3315cf92014-10-23 10:27:35 -040030 return fmt.Sprintf("%s%#v%s", yellowCode, str, resetCode)
Jamie Hannaford6d8dcd02014-09-25 13:55:09 +020031}
32
33func logFatal(t *testing.T, str string) {
Ash Wilson3315cf92014-10-23 10:27:35 -040034 t.Fatalf(logBodyFmt, prefix(3), str)
Jamie Hannaford6d8dcd02014-09-25 13:55:09 +020035}
36
37func logError(t *testing.T, str string) {
Ash Wilson3315cf92014-10-23 10:27:35 -040038 t.Errorf(logBodyFmt, prefix(3), str)
Ash Wilson4e034de2014-10-21 13:56:20 -040039}
40
// diffLogger is called once for every difference found by deepDiff. It
// receives the reference path followed from the root value to the difference
// (for example [".Servers", "[0]", ".Name"]), then the expected and actual
// values at that location. A value missing on one side is passed as nil.
type diffLogger func([]string, interface{}, interface{})

// visit identifies a pair of addressable values of one type that deepDiffEqual
// has already begun comparing. a1 holds the lower address and a2 the higher,
// so the pair is recorded the same way regardless of argument order.
// Remembering visits keeps the walk from looping forever on cyclic structures.
type visit struct {
	a1  uintptr
	a2  uintptr
	typ reflect.Type
}

// extendPath returns a fresh copy of path with hop appended. Copying (rather
// than a bare append) guarantees that sibling hops never share a backing
// array, so a path handed to a diffLogger is not overwritten later in the walk.
func extendPath(path []string, hop string) []string {
	next := make([]string, len(path), len(path)+1)
	copy(next, path)
	return append(next, hop)
}

// deepDiffEqual recursively walks "expected" and "actual" in parallel. The
// diffLogger function is invoked with each differing value encountered,
// including the reference path that was followed to get there. Unlike
// reflect.DeepEqual it does not stop at the first mismatch.
//
// Comparison rules:
//   - A value present on only one side (an invalid reflect.Value) is reported
//     against nil.
//   - Values of different types are reported as a single difference.
//   - Slices are walked up to the longer length; elements present on only one
//     side are reported individually.
//   - Maps are walked over the union of their keys; a key missing from one
//     side is reported at that key's path.
//   - Non-nil funcs are treated as equal; only nil-ness is compared.
//   - Remaining kinds are compared with ==.
//
// If comparing a pair panics (for instance, calling Interface() on a value
// reached through an unexported struct field), the pair is compared with
// reflect.DeepEqual instead and reported as a whole if it differs. If that
// pair cannot be converted to an interface either, the panic propagates to the
// enclosing level, which applies the same fallback.
func deepDiffEqual(expected, actual reflect.Value, visited map[visit]bool, path []string, logDifference diffLogger) {
	defer func() {
		// Fall back to the regular reflect.DeepEqual function.
		if r := recover(); r != nil {
			var e, a interface{}
			if expected.IsValid() {
				e = expected.Interface()
			}
			if actual.IsValid() {
				a = actual.Interface()
			}

			if !reflect.DeepEqual(e, a) {
				logDifference(path, e, a)
			}
		}
	}()

	if !expected.IsValid() && actual.IsValid() {
		logDifference(path, nil, actual.Interface())
		return
	}
	if expected.IsValid() && !actual.IsValid() {
		logDifference(path, expected.Interface(), nil)
		return
	}
	if !expected.IsValid() && !actual.IsValid() {
		return
	}

	// Values of different types can never be equal (reflect.DeepEqual agrees).
	// This arises when interface values hold different dynamic types; walking
	// one with the other's shape would compare unrelated parts or panic.
	if expected.Type() != actual.Type() {
		logDifference(path, expected.Interface(), actual.Interface())
		return
	}

	hard := func(k reflect.Kind) bool {
		switch k {
		case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct:
			return true
		}
		return false
	}

	if expected.CanAddr() && actual.CanAddr() && hard(expected.Kind()) {
		addr1 := expected.UnsafeAddr()
		addr2 := actual.UnsafeAddr()

		if addr1 > addr2 {
			addr1, addr2 = addr2, addr1
		}

		if addr1 == addr2 {
			// References are identical. We can short-circuit.
			return
		}

		typ := expected.Type()
		v := visit{addr1, addr2, typ}
		if visited[v] {
			// Already visited.
			return
		}

		// Remember this visit for later.
		visited[v] = true
	}

	switch expected.Kind() {
	case reflect.Array:
		// Same type implies same length for arrays.
		for i := 0; i < expected.Len(); i++ {
			hop := extendPath(path, fmt.Sprintf("[%d]", i))
			deepDiffEqual(expected.Index(i), actual.Index(i), visited, hop, logDifference)
		}
		return
	case reflect.Slice:
		if expected.IsNil() != actual.IsNil() {
			logDifference(path, expected.Interface(), actual.Interface())
			return
		}
		if expected.Len() == actual.Len() && expected.Pointer() == actual.Pointer() {
			return
		}
		// Walk up to the longer length. An index past the end of one side is
		// left as an invalid Value, which the recursive call reports as missing.
		n := expected.Len()
		if actual.Len() > n {
			n = actual.Len()
		}
		for i := 0; i < n; i++ {
			var e, a reflect.Value
			if i < expected.Len() {
				e = expected.Index(i)
			}
			if i < actual.Len() {
				a = actual.Index(i)
			}
			hop := extendPath(path, fmt.Sprintf("[%d]", i))
			deepDiffEqual(e, a, visited, hop, logDifference)
		}
		return
	case reflect.Interface:
		if expected.IsNil() != actual.IsNil() {
			logDifference(path, expected.Interface(), actual.Interface())
			return
		}
		deepDiffEqual(expected.Elem(), actual.Elem(), visited, path, logDifference)
		return
	case reflect.Ptr:
		// A nil pointer yields an invalid Elem, reported as missing above.
		deepDiffEqual(expected.Elem(), actual.Elem(), visited, path, logDifference)
		return
	case reflect.Struct:
		for i, n := 0, expected.NumField(); i < n; i++ {
			field := expected.Type().Field(i)
			hop := extendPath(path, "."+field.Name)
			deepDiffEqual(expected.Field(i), actual.Field(i), visited, hop, logDifference)
		}
		return
	case reflect.Map:
		if expected.IsNil() != actual.IsNil() {
			logDifference(path, expected.Interface(), actual.Interface())
			return
		}
		if expected.Len() == actual.Len() && expected.Pointer() == actual.Pointer() {
			return
		}

		// Compare over the union of both key sets. MapIndex returns an invalid
		// Value for a key absent from that map, which the recursive call
		// reports as missing at the key's own path.
		keys := expected.MapKeys()
		for _, k := range actual.MapKeys() {
			if !expected.MapIndex(k).IsValid() {
				keys = append(keys, k)
			}
		}

		for _, k := range keys {
			hop := extendPath(path, fmt.Sprintf("[%v]", k))
			deepDiffEqual(expected.MapIndex(k), actual.MapIndex(k), visited, hop, logDifference)
		}
		return
	case reflect.Func:
		if expected.IsNil() != actual.IsNil() {
			logDifference(path, expected.Interface(), actual.Interface())
		}
		return
	default:
		if expected.Interface() != actual.Interface() {
			logDifference(path, expected.Interface(), actual.Interface())
		}
	}
}

// deepDiff compares expected and actual in depth and calls logDifference once
// for every difference found, with the path that leads to it. Two nil values
// are equal; a nil value compared with a non-nil one, or two values of
// different types, produce a single difference at the root (an empty path).
func deepDiff(expected, actual interface{}, logDifference diffLogger) {
	if expected == nil && actual == nil {
		return
	}
	if expected == nil || actual == nil {
		logDifference([]string{}, expected, actual)
		return
	}

	expectedValue := reflect.ValueOf(expected)
	actualValue := reflect.ValueOf(actual)

	if expectedValue.Type() != actualValue.Type() {
		logDifference([]string{}, expected, actual)
		return
	}
	deepDiffEqual(expectedValue, actualValue, map[visit]bool{}, []string{}, logDifference)
}
211
Jamie Hannaford0f26e5c2014-09-15 15:46:58 +0200212// AssertEquals compares two arbitrary values and performs a comparison. If the
Jamie Hannaford2964aed2014-09-15 12:20:02 +0200213// comparison fails, a fatal error is raised that will fail the test
Jamie Hannaford0f26e5c2014-09-15 15:46:58 +0200214func AssertEquals(t *testing.T, expected, actual interface{}) {
Jamie Hannafordb2b237f2014-09-15 12:17:47 +0200215 if expected != actual {
Jamie Hannaford6d8dcd02014-09-25 13:55:09 +0200216 logFatal(t, fmt.Sprintf("expected %s but got %s", green(expected), yellow(actual)))
Jamie Hannafordb2b237f2014-09-15 12:17:47 +0200217 }
218}
219
Jamie Hannaford0f26e5c2014-09-15 15:46:58 +0200220// CheckEquals is similar to AssertEquals, except with a non-fatal error
221func CheckEquals(t *testing.T, expected, actual interface{}) {
222 if expected != actual {
Jamie Hannaford6d8dcd02014-09-25 13:55:09 +0200223 logError(t, fmt.Sprintf("expected %s but got %s", green(expected), yellow(actual)))
Jamie Hannaford0f26e5c2014-09-15 15:46:58 +0200224 }
225}
226
227// AssertDeepEquals - like Equals - performs a comparison - but on more complex
Jamie Hannaford6bcf2582014-09-15 12:52:51 +0200228// structures that requires deeper inspection
Jamie Hannaford6d8dcd02014-09-25 13:55:09 +0200229func AssertDeepEquals(t *testing.T, expected, actual interface{}) {
Ash Wilson4e034de2014-10-21 13:56:20 -0400230 pre := prefix(2)
231
232 differed := false
233 deepDiff(expected, actual, func(path []string, expected, actual interface{}) {
234 differed = true
235 t.Errorf("\033[1;31m%sat %s expected %s, but got %s\033[0m",
236 pre,
237 strings.Join(path, ""),
238 green(expected),
239 yellow(actual))
240 })
241 if differed {
242 logFatal(t, "The structures were different.")
Jamie Hannaford6bcf2582014-09-15 12:52:51 +0200243 }
244}
245
Jamie Hannaford0f26e5c2014-09-15 15:46:58 +0200246// CheckDeepEquals is similar to AssertDeepEquals, except with a non-fatal error
Jamie Hannaford6d8dcd02014-09-25 13:55:09 +0200247func CheckDeepEquals(t *testing.T, expected, actual interface{}) {
Ash Wilson4e034de2014-10-21 13:56:20 -0400248 pre := prefix(2)
249
250 deepDiff(expected, actual, func(path []string, expected, actual interface{}) {
251 t.Errorf("\033[1;31m%s at %s expected %s, but got %s\033[0m",
252 pre,
253 strings.Join(path, ""),
254 green(expected),
255 yellow(actual))
256 })
Jamie Hannaford0f26e5c2014-09-15 15:46:58 +0200257}
258
Ash Wilson3315cf92014-10-23 10:27:35 -0400259// isJSONEquals is a utility function that implements JSON comparison for AssertJSONEquals and
260// CheckJSONEquals.
261func isJSONEquals(t *testing.T, expectedJSON string, actual interface{}) bool {
262 var parsedExpected interface{}
263 err := json.Unmarshal([]byte(expectedJSON), &parsedExpected)
264 if err != nil {
265 t.Errorf("Unable to parse expected value as JSON: %v", err)
266 return false
267 }
268
269 if !reflect.DeepEqual(parsedExpected, actual) {
270 prettyExpected, err := json.MarshalIndent(parsedExpected, "", " ")
271 if err != nil {
272 t.Logf("Unable to pretty-print expected JSON: %v\n%s", err, expectedJSON)
273 } else {
274 // We can't use green() here because %#v prints prettyExpected as a byte array literal, which
275 // is... unhelpful. Converting it to a string first leaves "\n" uninterpreted for some reason.
276 t.Logf("Expected JSON:\n%s%s%s", greenCode, prettyExpected, resetCode)
277 }
278
279 prettyActual, err := json.MarshalIndent(actual, "", " ")
280 if err != nil {
281 t.Logf("Unable to pretty-print actual JSON: %v\n%#v", err, actual)
282 } else {
283 // We can't use yellow() for the same reason.
284 t.Logf("Actual JSON:\n%s%s%s", yellowCode, prettyActual, resetCode)
285 }
286
287 return false
288 }
289 return true
290}
291
292// AssertJSONEquals serializes a value as JSON, parses an expected string as JSON, and ensures that
293// both are consistent. If they aren't, the expected and actual structures are pretty-printed and
294// shown for comparison.
295//
296// This is useful for comparing structures that are built as nested map[string]interface{} values,
297// which are a pain to construct as literals.
298func AssertJSONEquals(t *testing.T, expectedJSON string, actual interface{}) {
299 if !isJSONEquals(t, expectedJSON, actual) {
300 logFatal(t, "The generated JSON structure differed.")
301 }
302}
303
304// CheckJSONEquals is similar to AssertJSONEquals, but nonfatal.
305func CheckJSONEquals(t *testing.T, expectedJSON string, actual interface{}) {
306 if !isJSONEquals(t, expectedJSON, actual) {
307 logError(t, "The generated JSON structure differed.")
308 }
309}
310
Jamie Hannaford0f26e5c2014-09-15 15:46:58 +0200311// AssertNoErr is a convenience function for checking whether an error value is
312// an actual error
313func AssertNoErr(t *testing.T, e error) {
Jamie Hannafordb2b237f2014-09-15 12:17:47 +0200314 if e != nil {
Jamie Hannaford9823bb62014-09-26 17:06:36 +0200315 logFatal(t, fmt.Sprintf("unexpected error %s", yellow(e.Error())))
Jamie Hannafordb2b237f2014-09-15 12:17:47 +0200316 }
317}
Jamie Hannaford0f26e5c2014-09-15 15:46:58 +0200318
319// CheckNoErr is similar to AssertNoErr, except with a non-fatal error
320func CheckNoErr(t *testing.T, e error) {
321 if e != nil {
Jamie Hannaford9823bb62014-09-26 17:06:36 +0200322 logError(t, fmt.Sprintf("unexpected error %s", yellow(e.Error())))
Jamie Hannaford0f26e5c2014-09-15 15:46:58 +0200323 }
324}