blob: cd99d58de58101a53a233f50a995c057425cdae2 [file] [log] [blame]
Philippe Antoine65ea7522021-03-15 09:34:58 +01001// +build gofuzz
2
3/*
4 * Licensed to the Apache Software Foundation (ASF) under one
5 * or more contributor license agreements. See the NOTICE file
6 * distributed with this work for additional information
7 * regarding copyright ownership. The ASF licenses this file
8 * to you under the Apache License, Version 2.0 (the
9 * "License"); you may not use this file except in compliance
10 * with the License. You may obtain a copy of the License at
11 *
12 * http://www.apache.org/licenses/LICENSE-2.0
13 *
14 * Unless required by applicable law or agreed to in writing,
15 * software distributed under the License is distributed on an
16 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
17 * KIND, either express or implied. See the License for the
18 * specific language governing permissions and limitations
19 * under the License.
20 */
21
22package fuzz
23
24import (
25 "context"
26 "fmt"
Philippe Antoine65ea7522021-03-15 09:34:58 +010027 "strconv"
Philippe Antoine65ea7522021-03-15 09:34:58 +010028
Yuxuan 'fishy' Wang3761f002021-03-25 15:41:53 -070029 "shared"
30 "tutorial"
31
Philippe Antoine65ea7522021-03-15 09:34:58 +010032 "github.com/apache/thrift/lib/go/thrift"
33)
34
35const nbFuzzedProtocols = 2
36
37func fuzzChooseProtocol(d byte, t thrift.TTransport) thrift.TProtocol {
38 switch d % nbFuzzedProtocols {
39 default:
40 fallthrough
41 case 0:
42 return thrift.NewTBinaryProtocolFactoryConf(nil).GetProtocol(t)
43 case 1:
44 return thrift.NewTCompactProtocolFactoryConf(nil).GetProtocol(t)
45 }
46}
47
48func Fuzz(data []byte) int {
49 if len(data) < 2 {
50 return 0
51 }
52 inputTransport := thrift.NewTMemoryBuffer()
53 inputTransport.Buffer.Write(data[2:])
54 outputTransport := thrift.NewTMemoryBuffer()
55 outputProtocol := fuzzChooseProtocol(data[0], outputTransport)
56 inputProtocol := fuzzChooseProtocol(data[1], inputTransport)
57 ctx := thrift.SetResponseHelper(
58 context.Background(),
59 thrift.TResponseHelper{
60 THeaderResponseHelper: thrift.NewTHeaderResponseHelper(outputProtocol),
61 },
62 )
63 handler := NewCalculatorHandler()
64 processor := tutorial.NewCalculatorProcessor(handler)
65 ok := true
66 var err error
67 for ok {
68 ok, err = processor.Process(ctx, inputProtocol, outputProtocol)
69 if err != nil {
70 // Handle parse error
71 return 0
72 }
73 res := make([]byte, 1024)
74 n, err := outputTransport.Buffer.Read(res)
75 fmt.Printf("lol %d %s %v\n", n, err, res)
76 }
77 return 1
78}
79
80type CalculatorHandler struct {
81 log map[int]*shared.SharedStruct
82}
83
84func NewCalculatorHandler() *CalculatorHandler {
85 return &CalculatorHandler{log: make(map[int]*shared.SharedStruct)}
86}
87
88func (p *CalculatorHandler) Ping(ctx context.Context) (err error) {
89 fmt.Print("ping()\n")
90 return nil
91}
92
93func (p *CalculatorHandler) Add(ctx context.Context, num1 int32, num2 int32) (retval17 int32, err error) {
94 fmt.Print("add(", num1, ",", num2, ")\n")
95 return num1 + num2, nil
96}
97
98func (p *CalculatorHandler) Calculate(ctx context.Context, logid int32, w *tutorial.Work) (val int32, err error) {
99 fmt.Print("calculate(", logid, ", {", w.Op, ",", w.Num1, ",", w.Num2, "})\n")
100 switch w.Op {
101 case tutorial.Operation_ADD:
102 val = w.Num1 + w.Num2
103 break
104 case tutorial.Operation_SUBTRACT:
105 val = w.Num1 - w.Num2
106 break
107 case tutorial.Operation_MULTIPLY:
108 val = w.Num1 * w.Num2
109 break
110 case tutorial.Operation_DIVIDE:
111 if w.Num2 == 0 {
112 ouch := tutorial.NewInvalidOperation()
113 ouch.WhatOp = int32(w.Op)
114 ouch.Why = "Cannot divide by 0"
115 err = ouch
116 return
117 }
118 val = w.Num1 / w.Num2
119 break
120 default:
121 ouch := tutorial.NewInvalidOperation()
122 ouch.WhatOp = int32(w.Op)
123 ouch.Why = "Unknown operation"
124 err = ouch
125 return
126 }
127 entry := shared.NewSharedStruct()
128 entry.Key = logid
129 entry.Value = strconv.Itoa(int(val))
130 k := int(logid)
131 p.log[k] = entry
132 return val, err
133}
134
135func (p *CalculatorHandler) GetStruct(ctx context.Context, key int32) (*shared.SharedStruct, error) {
136 fmt.Print("getStruct(", key, ")\n")
137 v, _ := p.log[int(key)]
138 return v, nil
139}
140
141func (p *CalculatorHandler) Zip(ctx context.Context) (err error) {
142 fmt.Print("zip()\n")
143 return nil
144}