diff --git a/request.go b/request.go index b3c96a6..2d83aa3 100644 --- a/request.go +++ b/request.go @@ -34,7 +34,8 @@ func prepareRequests(command string, params []string) (requests []RawRequest, er } switch len(params) { case 0: - if len(command) >= BUF_SIZE { + // +2 because of the 'j/' + if len(command)+2 > BUF_SIZE { return nil, fmt.Errorf( "command is too long (%d>=%d): %s", len(command), @@ -45,7 +46,8 @@ func prepareRequests(command string, params []string) (requests []RawRequest, er requests = append(requests, []byte(command)) case 1: request := command + " " + params[0] - if len(request) >= BUF_SIZE { + // +2 because of the 'j/' + if len(request)+2 > BUF_SIZE { return nil, fmt.Errorf( "command is too long (%d>=%d): %s", len(request), @@ -65,11 +67,12 @@ func prepareRequests(command string, params []string) (requests []RawRequest, er for _, param := range params { // Get the current command + param length - cmdLen := len(command) + len(param) + 2 // ; + - if len(batch)+cmdLen >= BUF_SIZE { - // If batch + command + param length is bigger - // than BUF_SIZE, return an error since it will - // not fit the socket + // +4 because of the 'j/; ' + cmdLen := len(command) + len(param) + 4 + if len(batch)+cmdLen > BUF_SIZE { + // If batch + command length is bigger than + // BUF_SIZE, return an error since it will not + // fit the socket return nil, fmt.Errorf( "command is too long (%d>=%d): %s%s %s;", len(batch)+cmdLen, @@ -78,7 +81,7 @@ func prepareRequests(command string, params []string) (requests []RawRequest, er command, param, ) - } else if curLen+cmdLen < BUF_SIZE { + } else if curLen+cmdLen <= BUF_SIZE { // If the current length of the buffer + // command + param is less than BUF_SIZE, the // request will fit @@ -229,6 +232,14 @@ func (c *RequestClient) RawRequest(request RawRequest) (response RawResponse, er // Send the request to the socket request = append([]byte{'j', '/'}, request...) + if len(request) > BUF_SIZE { + return nil, fmt.Errorf( + "request too big (%d>%d): %s", + len(request), + BUF_SIZE, + request, + ) + } _, err = conn.Write(request) if err != nil { return nil, fmt.Errorf("error while writing to socket: %w", err) diff --git a/request_test.go b/request_test.go index 45cdfd5..2be0fd2 100644 --- a/request_test.go +++ b/request_test.go @@ -112,19 +112,19 @@ func TestPrepareRequestsMass(t *testing.T) { func TestPrepareRequestsError(t *testing.T) { _, err := prepareRequests( - strings.Repeat("c", BUF_SIZE), + strings.Repeat("c", BUF_SIZE-1), nil, ) assert.Error(t, err) _, err = prepareRequests( - strings.Repeat("c", BUF_SIZE-len("p ")), + strings.Repeat("c", BUF_SIZE-len("p ")-1), genParams("p", 1), ) assert.Error(t, err) _, err = prepareRequests( - strings.Repeat("c", BUF_SIZE-len("[[BATCH]]"+"p ;")), + strings.Repeat("c", BUF_SIZE-len("[[BATCH]] p;")-1), genParams("p", 5), ) assert.Error(t, err)