THRIFT-5423: IDL parameter validation for Go
Closes https://github.com/apache/thrift/pull/2469.
diff --git a/lib/go/test/Makefile.am b/lib/go/test/Makefile.am
index 992a843..b9c00d9 100644
--- a/lib/go/test/Makefile.am
+++ b/lib/go/test/Makefile.am
@@ -64,7 +64,8 @@
ConstOptionalFieldImport.thrift \
ConstOptionalField.thrift \
ProcessorMiddlewareTest.thrift \
- ClientMiddlewareExceptionTest.thrift
+ ClientMiddlewareExceptionTest.thrift \
+ ValidateTest.thrift
mkdir -p gopath/src
grep -v list.*map.*list.*map $(THRIFTTEST) | grep -v 'set<Insanity>' > ThriftTest.thrift
$(THRIFT) $(THRIFTARGS) -r IncludesTest.thrift
@@ -98,6 +99,7 @@
$(THRIFT) $(THRIFTARGS) -r ConstOptionalField.thrift
$(THRIFT) $(THRIFTARGS_SKIP_REMOTE) ProcessorMiddlewareTest.thrift
$(THRIFT) $(THRIFTARGS) ClientMiddlewareExceptionTest.thrift
+ $(THRIFT) $(THRIFTARGS) ValidateTest.thrift
ln -nfs ../../tests gopath/src/tests
cp -r ./dontexportrwtest gopath/src
touch gopath
@@ -122,7 +124,8 @@
./gopath/src/equalstest \
./gopath/src/conflictargnamestest \
./gopath/src/processormiddlewaretest \
- ./gopath/src/clientmiddlewareexceptiontest
+ ./gopath/src/clientmiddlewareexceptiontest \
+ ./gopath/src/validatetest
$(GO) test -mod=mod github.com/apache/thrift/lib/go/thrift
$(GO) test -mod=mod ./gopath/src/tests ./gopath/src/dontexportrwtest
@@ -168,4 +171,5 @@
ServicesTest.thrift \
TypedefFieldTest.thrift \
UnionBinaryTest.thrift \
- UnionDefaultValueTest.thrift
+ UnionDefaultValueTest.thrift \
+ ValidateTest.thrift
diff --git a/lib/go/test/ValidateTest.thrift b/lib/go/test/ValidateTest.thrift
new file mode 100644
index 0000000..c02bfa8
--- /dev/null
+++ b/lib/go/test/ValidateTest.thrift
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+namespace go validatetest
+
+enum EnumFoo { // enum target for the vt.in / vt.not_in / vt.defined_only rules below
+ e1
+ e2
+}
+
+struct Foo { // nested struct type; every field of this type below is annotated vt.skip
+ 1: bool Bool
+}
+
+struct BasicTest { // vt.* rules whose operands are literal values
+ 1: bool Bool0 = true (vt.const = "true")
+ 2: optional bool Bool1 (vt.const = "true")
+ 3: i8 Byte0 = 1 (vt.lt = "2", vt.le = "2", vt.gt = "0", vt.ge = "0", vt.in = "[0, 1, 2]", vt.not_in = "[3, 4, 5]")
+ 4: optional i8 Byte1 (vt.lt = "1", vt.le = "1", vt.gt = "-1", vt.ge = "-1", vt.in = "[-1, 0, 1]", vt.not_in = "[1, 2, 3]") // NOTE(review): Byte1=3 breaks vt.lt, vt.le, vt.in and vt.not_in, yet validate_test.go expects vt.lt - presumably rules are checked in annotation order; TODO confirm before reordering
+ 5: double Double0 = 1.0 (vt.lt = "2.0", vt.le = "2.0", vt.gt = "0", vt.ge = "0", vt.in = "[0, 1.0, 2.0]", vt.not_in = "[3.0, 4.0, 5.0]")
+ 6: optional double Double1 (vt.lt = "2.0", vt.le = "2.0", vt.gt = "0", vt.ge = "0", vt.in = "[0, 1.0, 2.0]", vt.not_in = "[3.0, 4.0, 5.0]")
+ 7: string String0 = "my const string" (vt.const = "my const string", vt.min_size = "0", vt.max_size = "100", vt.pattern = ".*", vt.prefix = "my", vt.suffix = "string", vt.contains = "const", vt.not_contains = "oh")
+ 8: optional string String1 (vt.const = "my const string", vt.min_size = "0", vt.max_size = "100", vt.pattern = ".*", vt.prefix = "my", vt.suffix = "string", vt.contains = "const", vt.not_contains = "oh")
+ 9: binary Binary0 = "my const string" (vt.const = "my const string", vt.min_size = "0", vt.max_size = "100", vt.pattern = ".*", vt.prefix = "my", vt.suffix = "string", vt.contains = "const", vt.not_contains = "oh")
+ 10: optional binary Binary1 = "my const string" (vt.const = "my const string", vt.min_size = "0", vt.max_size = "100", vt.pattern = ".*", vt.prefix = "my", vt.suffix = "string", vt.contains = "const", vt.not_contains = "oh")
+ 11: map<string, string> Map0 (vt.min_size = "0", vt.max_size = "10", vt.key.min_size = "0", vt.key.max_size = "10", vt.value.min_size = "0", vt.value.max_size = "10")
+ 12: optional map<string, string> Map1 (vt.min_size = "0", vt.max_size = "10", vt.key.min_size = "0", vt.key.max_size = "10", vt.value.min_size = "0", vt.value.max_size = "10")
+ 13: set<string> Set0 (vt.min_size = "0", vt.max_size = "10", vt.elem.min_size = "5")
+ 14: optional set<string> Set1 (vt.min_size = "0", vt.max_size = "10", vt.elem.min_size = "5")
+ 15: EnumFoo Enum0 = EnumFoo.e2 (vt.in = "[EnumFoo.e2]", vt.defined_only = "true")
+ 16: optional EnumFoo Enum1 (vt.in = "[EnumFoo.e1]", vt.defined_only = "true")
+ 17: Foo Struct0 (vt.skip = "true")
+ 18: optional Foo Struct1 (vt.skip = "true")
+ 19: i8 Byte2 = 1 (vt.in = "1", vt.not_in = "2") // fields 19-21: scalar (non-list) vt.in / vt.not_in operands
+ 20: double Double2 = 3.0 (vt.in = "3.0", vt.not_in = "4.0")
+ 21: EnumFoo Enum2 = EnumFoo.e2 (vt.in = "EnumFoo.e2", vt.not_in = "EnumFoo.e1")
+}
+
+struct FieldReferenceTest { // vt.* operands that reference sibling fields via $FieldName
+ 1: bool Bool0 (vt.const = "$Bool2")
+ 2: optional bool Bool1 (vt.const = "$Bool2")
+ 3: i8 Byte0 = 10 (vt.lt = "$Byte4", vt.le = "$Byte4", vt.gt = "$Byte2", vt.ge = "$Byte2", vt.in = "[$Byte2, $Byte3, $Byte4]", vt.not_in = "[$Byte2, $Byte4]")
+ 4: optional i8 Byte1 (vt.lt = "$Byte4", vt.le = "$Byte4", vt.gt = "$Byte2", vt.ge = "$Byte2", vt.in = "[$Byte2, $Byte3, $Byte4]", vt.not_in = "[$Byte2, $Byte4]")
+ 5: double Double0 = 10.0 (vt.lt = "$Double4", vt.le = "$Double4", vt.gt = "$Double2", vt.ge = "$Double2", vt.in = "[$Double2, $Double3, $Double4]", vt.not_in = "[$Double2, $Double4]")
+ 6: optional double Double1 (vt.lt = "$Double4", vt.le = "$Double4", vt.gt = "$Double2", vt.ge = "$Double2", vt.in = "[$Double2, $Double3, $Double4]", vt.not_in = "[$Double2, $Double4]")
+ 7: string String0 = "my string" (vt.const = "$String2", vt.min_size = "$Byte2", vt.max_size = "$Byte3", vt.pattern = "$String4", vt.prefix = "$String2", vt.suffix = "$String2", vt.contains = "$String2", vt.not_contains = "$String3")
+ 8: optional string String1 (vt.const = "$String2", vt.min_size = "$Byte2", vt.max_size = "$Byte3", vt.pattern = "$String4", vt.prefix = "$String2", vt.suffix = "$String2", vt.contains = "$String2", vt.not_contains = "$String3")
+ 9: binary Binary0 = "my binary" (vt.const = "$Binary2", vt.min_size = "$Byte2", vt.max_size = "$Byte3", vt.pattern = "$Binary4", vt.prefix = "$Binary2", vt.suffix = "$Binary2", vt.contains = "$Binary2", vt.not_contains = "$Binary3")
+ 10: optional binary Binary1 = "my binary" (vt.const = "$Binary2", vt.min_size = "$Byte2", vt.max_size = "$Byte3", vt.pattern = "$Binary4", vt.prefix = "$Binary2", vt.suffix = "$Binary2", vt.contains = "$Binary2", vt.not_contains = "$Binary3")
+ 11: map<string, string> Map0 (vt.min_size = "$Byte2", vt.max_size = "$MaxSize", vt.key.min_size = "$Byte2", vt.key.max_size = "$MaxSize", vt.value.min_size = "$Byte2", vt.value.max_size = "$MaxSize")
+ 12: optional map<string, string> Map1 (vt.min_size = "$Byte2", vt.max_size = "$MaxSize", vt.key.min_size = "$Byte2", vt.key.max_size = "$MaxSize", vt.value.min_size = "$Byte2", vt.value.max_size = "$MaxSize")
+ 13: list<string> List0 (vt.min_size = "$Byte2", vt.max_size = "$MaxSize", vt.elem.min_size = "$Byte2", vt.elem.max_size = "$MaxSize")
+ 14: optional list<string> List1 (vt.min_size = "$Byte2", vt.max_size = "$MaxSize", vt.elem.min_size = "$Byte2", vt.elem.max_size = "$MaxSize")
+ 15: set<string> Set0 (vt.min_size = "$Byte2", vt.max_size = "$MaxSize", vt.elem.min_size = "$Byte2", vt.elem.max_size = "$MaxSize")
+ 16: optional set<string> Set1 (vt.min_size = "$Byte2", vt.max_size = "$MaxSize", vt.elem.min_size = "$Byte2", vt.elem.max_size = "$MaxSize")
+ 17: bool Bool2 = false // fields 17-30 are the $-referenced operands of the rules above
+ 18: i8 Byte2 = 0
+ 19: i8 Byte3 = 10
+ 20: i8 Byte4 = 20
+ 21: double Double2 = 0
+ 22: double Double3 = 10.0
+ 23: double Double4 = 20.0
+ 24: string String2 = "my string"
+ 25: string String3 = "other string"
+ 26: string String4 = ".*"
+ 27: binary Binary2 = "my binary"
+ 28: binary Binary3 = "other binary"
+ 29: binary Binary4 = ".*"
+ 30: i64 MaxSize = 10
+}
+
+struct ValidationFunctionTest { // vt.* operand computed by a validation function (@len over $StringFoo)
+ 1: string StringFoo
+ 2: i64 StringLength (vt.in = "[@len($StringFoo)]")
+}
+
+struct AnnotationCompatibleTest { // vt.* rules combined with go.tag; validate_test.go checks the go.tag JSON names survive
+ 1: bool Bool0 = true (vt.const = "true", go.tag = 'json:"bool1"')
+ 2: i8 Byte0 = 1 (vt.lt = "2", go.tag = 'json:"byte1"')
+ 3: double Double0 = 1.0 (vt.lt = "2.0", go.tag = 'json:"double1"')
+ 4: string String0 = "my const string" (vt.const = "my const string", go.tag = 'json:"string1"')
+ 5: binary Binary0 = "my const string" (vt.const = "my const string", go.tag = 'json:"binary1"')
+ 6: map<string, string> Map0 (vt.max_size = "2", go.tag = 'json:"map1"')
+ 7: set<string> Set0 (vt.max_size = "2", go.tag = 'json:"set1"')
+ 8: list<string> List0 (vt.max_size = "2", go.tag = 'json:"list1"')
+ 9: EnumFoo Enum0 = EnumFoo.e2 (vt.in = "[EnumFoo.e2]", go.tag = 'json:"enum1"')
+ 10: Foo Struct0 (vt.skip = "true", go.tag = 'json:"struct1"')
+}
diff --git a/lib/go/test/tests/validate_test.go b/lib/go/test/tests/validate_test.go
new file mode 100644
index 0000000..957a8df
--- /dev/null
+++ b/lib/go/test/tests/validate_test.go
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package tests
+
+import (
+	"encoding/json"
+	"errors"
+	"strconv"
+	"testing"
+
+	"github.com/apache/thrift/lib/go/test/gopath/src/validatetest"
+	thrift "github.com/apache/thrift/lib/go/thrift"
+)
+
+// expectValidationError reports a test failure unless err is non-nil and
+// unwraps (via errors.As) to a *thrift.ValidationError whose Check() and
+// Field() equal the expected check and field.
+func expectValidationError(t *testing.T, err error, check, field string) {
+	t.Helper()
+	if err == nil {
+		t.Errorf("Expected %s error for %s", check, field)
+		return
+	}
+	var ve *thrift.ValidationError
+	if !errors.As(err, &ve) {
+		t.Errorf("Error cannot be unwrapped into *ValidationError: %v", err)
+		return
+	}
+	if ve.Check() != check {
+		t.Errorf("Expected %s check error, but got %v", check, ve.Check())
+	}
+	if ve.Field() != field {
+		t.Errorf("Expected error for %s, but got %v", field, ve.Field())
+	}
+}
+
+func TestBasicValidator(t *testing.T) {
+	bt := validatetest.NewBasicTest()
+	if err := bt.Validate(); err != nil {
+		t.Error(err)
+	}
+
+	bt = validatetest.NewBasicTest()
+	bt.Bool1 = thrift.BoolPtr(false)
+	expectValidationError(t, bt.Validate(), "vt.const", "Bool1")
+
+	bt = validatetest.NewBasicTest()
+	bt.Byte1 = thrift.Int8Ptr(3)
+	expectValidationError(t, bt.Validate(), "vt.lt", "Byte1")
+
+	bt = validatetest.NewBasicTest()
+	bt.Double1 = thrift.Float64Ptr(3.0)
+	expectValidationError(t, bt.Validate(), "vt.lt", "Double1")
+
+	bt = validatetest.NewBasicTest()
+	bt.String1 = thrift.StringPtr("other string")
+	expectValidationError(t, bt.Validate(), "vt.const", "String1")
+
+	bt = validatetest.NewBasicTest()
+	bt.Binary1 = []byte("other binary")
+	expectValidationError(t, bt.Validate(), "vt.const", "Binary1")
+
+	bt = validatetest.NewBasicTest()
+	bt.Map1 = make(map[string]string)
+	for i := 0; i < 11; i++ {
+		bt.Map1[strconv.Itoa(i)] = strconv.Itoa(i)
+	}
+	expectValidationError(t, bt.Validate(), "vt.max_size", "Map1")
+	// Over-long map key (vt.key.max_size = "10").
+	bt.Map1 = map[string]string{"012345678910": "0"}
+	expectValidationError(t, bt.Validate(), "vt.max_size", "Map1")
+	// Over-long map value (vt.value.max_size = "10").
+	bt.Map1 = map[string]string{"0": "012345678910"}
+	expectValidationError(t, bt.Validate(), "vt.max_size", "Map1")
+
+	bt = validatetest.NewBasicTest()
+	for i := 0; i < 11; i++ {
+		bt.Set1 = append(bt.Set1, "0")
+	}
+	expectValidationError(t, bt.Validate(), "vt.max_size", "Set1")
+	// Element shorter than vt.elem.min_size = "5".
+	bt.Set1 = []string{"0"}
+	expectValidationError(t, bt.Validate(), "vt.min_size", "Set1")
+
+	bt = validatetest.NewBasicTest()
+	bt.Enum1 = (*validatetest.EnumFoo)(thrift.Int64Ptr(int64(validatetest.EnumFoo_e2)))
+	expectValidationError(t, bt.Validate(), "vt.in", "Enum1")
+}
+
+func TestFieldReference(t *testing.T) {
+	frt := validatetest.NewFieldReferenceTest()
+	if err := frt.Validate(); err != nil {
+		t.Error(err)
+	}
+
+	frt = validatetest.NewFieldReferenceTest()
+	frt.Bool2 = true
+	expectValidationError(t, frt.Validate(), "vt.const", "Bool0")
+
+	frt = validatetest.NewFieldReferenceTest()
+	frt.Byte4 = 9
+	expectValidationError(t, frt.Validate(), "vt.lt", "Byte0")
+
+	frt = validatetest.NewFieldReferenceTest()
+	frt.Double4 = 9
+	expectValidationError(t, frt.Validate(), "vt.lt", "Double0")
+
+	frt = validatetest.NewFieldReferenceTest()
+	frt.String2 = "other string"
+	expectValidationError(t, frt.Validate(), "vt.const", "String0")
+
+	frt = validatetest.NewFieldReferenceTest()
+	frt.Binary2 = []byte("other string")
+	expectValidationError(t, frt.Validate(), "vt.const", "Binary0")
+
+	frt = validatetest.NewFieldReferenceTest()
+	frt.MaxSize = 8
+	frt.Map0 = make(map[string]string)
+	for i := 0; i < 9; i++ {
+		frt.Map0[strconv.Itoa(i)] = strconv.Itoa(i)
+	}
+	expectValidationError(t, frt.Validate(), "vt.max_size", "Map0")
+
+	frt = validatetest.NewFieldReferenceTest()
+	frt.MaxSize = 8
+	for i := 0; i < 9; i++ {
+		frt.List0 = append(frt.List0, "0")
+	}
+	expectValidationError(t, frt.Validate(), "vt.max_size", "List0")
+
+	frt = validatetest.NewFieldReferenceTest()
+	frt.MaxSize = 8
+	for i := 0; i < 9; i++ {
+		frt.Set0 = append(frt.Set0, "0")
+	}
+	expectValidationError(t, frt.Validate(), "vt.max_size", "Set0")
+}
+
+func TestValidationFunction(t *testing.T) {
+	vft := validatetest.NewValidationFunctionTest()
+	if err := vft.Validate(); err != nil {
+		t.Error(err)
+	}
+
+	vft = validatetest.NewValidationFunctionTest()
+	vft.StringFoo = "some string"
+	expectValidationError(t, vft.Validate(), "vt.in", "StringLength")
+}
+
+func TestAnnotationCompatibleTest(t *testing.T) {
+	act := validatetest.NewAnnotationCompatibleTest()
+	if err := act.Validate(); err != nil {
+		t.Error(err)
+	}
+
+	act = validatetest.NewAnnotationCompatibleTest()
+	act.Bool0 = false
+	expectValidationError(t, act.Validate(), "vt.const", "Bool0")
+
+	act = validatetest.NewAnnotationCompatibleTest()
+	act.Byte0 = 3
+	expectValidationError(t, act.Validate(), "vt.lt", "Byte0")
+
+	act = validatetest.NewAnnotationCompatibleTest()
+	act.Double0 = 3
+	expectValidationError(t, act.Validate(), "vt.lt", "Double0")
+
+	act = validatetest.NewAnnotationCompatibleTest()
+	act.String0 = "other string"
+	expectValidationError(t, act.Validate(), "vt.const", "String0")
+
+	act = validatetest.NewAnnotationCompatibleTest()
+	act.Binary0 = []byte("other string")
+	expectValidationError(t, act.Validate(), "vt.const", "Binary0")
+
+	act = validatetest.NewAnnotationCompatibleTest()
+	act.Map0 = map[string]string{"0": "0", "1": "1", "2": "2"}
+	expectValidationError(t, act.Validate(), "vt.max_size", "Map0")
+
+	act = validatetest.NewAnnotationCompatibleTest()
+	act.Set0 = []string{"0", "1", "2"}
+	expectValidationError(t, act.Validate(), "vt.max_size", "Set0")
+
+	act = validatetest.NewAnnotationCompatibleTest()
+	act.List0 = []string{"0", "1", "2"}
+	expectValidationError(t, act.Validate(), "vt.max_size", "List0")
+
+	act = validatetest.NewAnnotationCompatibleTest()
+	act.Enum0 = validatetest.EnumFoo_e1
+	expectValidationError(t, act.Validate(), "vt.in", "Enum0")
+
+	// go.tag must still apply next to vt.* rules: every custom JSON name is
+	// expected as a key of the marshaled struct. Marshal failures are fatal.
+	fields := []string{"bool1", "byte1", "double1", "string1", "binary1", "enum1", "struct1", "list1", "set1", "map1"}
+	b, err := json.Marshal(act)
+	if err != nil {
+		t.Fatal(err)
+	}
+	jsonMap := make(map[string]interface{})
+	if err = json.Unmarshal(b, &jsonMap); err != nil {
+		t.Fatal(err)
+	}
+	for _, field := range fields {
+		if _, ok := jsonMap[field]; !ok {
+			t.Errorf("Expected field %s in JSON, but not found", field)
+		}
+	}
+}
diff --git a/lib/go/thrift/application_exception.go b/lib/go/thrift/application_exception.go
index ed85a64..8b8137a 100644
--- a/lib/go/thrift/application_exception.go
+++ b/lib/go/thrift/application_exception.go
@@ -21,6 +21,7 @@
import (
"context"
+ "strings"
)
const (
@@ -35,6 +36,7 @@
INVALID_TRANSFORM = 8
INVALID_PROTOCOL = 9
UNSUPPORTED_CLIENT_TYPE = 10
+ VALIDATION_FAILED = 11
)
var defaultApplicationExceptionMessage = map[int32]string{
@@ -49,6 +51,7 @@
INVALID_TRANSFORM: "Invalid transform",
INVALID_PROTOCOL: "Invalid protocol",
UNSUPPORTED_CLIENT_TYPE: "Unsupported client type",
+ VALIDATION_FAILED: "validation failed",
}
// Application level Thrift exception
@@ -59,9 +62,51 @@
Write(ctx context.Context, oprot TProtocol) error
}
+// ValidationError records a failed vt.* validation check: the name of the
+// check, the symbol of the field it was applied to, and a message. It is
+// wrapped by NewValidationException and can be extracted with errors.As.
+type ValidationError struct {
+	message     string
+	check       string
+	fieldSymbol string
+}
+
+// Check returns the name of the failed check, e.g. "vt.lt".
+func (e *ValidationError) Check() string {
+	return e.check
+}
+
+// TypeName returns the part of the field symbol before the first '.'
+// (the whole symbol if it contains no '.').
+func (e *ValidationError) TypeName() string {
+	return strings.Split(e.fieldSymbol, ".")[0]
+}
+
+// Field returns the second '.'-separated component of the field symbol,
+// or the whole symbol if it contains no '.'.
+func (e *ValidationError) Field() string {
+	if fs := strings.Split(e.fieldSymbol, "."); len(fs) > 1 {
+		return fs[1]
+	}
+	return e.fieldSymbol
+}
+
+// FieldSymbol returns the unmodified field symbol (the field argument
+// given to NewValidationException).
+func (e *ValidationError) FieldSymbol() string {
+	return e.fieldSymbol
+}
+
+// Error returns the failure message. It uses a value receiver, so both
+// ValidationError and *ValidationError satisfy the error interface.
+func (e ValidationError) Error() string {
+	return e.message
+}
+
type tApplicationException struct {
message string
type_ int32
+	err     error // wrapped cause, returned by Unwrap
}
var _ TApplicationException = (*tApplicationException)(nil)
@@ -77,8 +122,26 @@
return defaultApplicationExceptionMessage[e.type_]
}
+// Unwrap returns the wrapped cause (nil if none) so errors.Is/errors.As can
+// inspect it, e.g. the *ValidationError attached by NewValidationException.
+func (e tApplicationException) Unwrap() error {
+	return e.err
+}
+
+// NewTApplicationException returns an exception of the given type with no wrapped cause.
func NewTApplicationException(type_ int32, message string) TApplicationException {
- return &tApplicationException{message, type_}
+	return &tApplicationException{message: message, type_: type_}
+}
+
+// NewValidationException returns a TApplicationException of type type_ with
+// the given message, whose Unwrap yields a *ValidationError recording the
+// failed check, the field symbol and the same message.
+func NewValidationException(type_ int32, check string, field string, message string) TApplicationException {
+	return &tApplicationException{
+		type_:   type_,
+		message: message,
+		err:     &ValidationError{message: message, check: check, fieldSymbol: field},
+	}
+}
func (p *tApplicationException) TypeId() int32 {