Skip to content

Commit

Permalink
fix tinyint marshal, unmarshal
Browse files Browse the repository at this point in the history
  • Loading branch information
illia-li committed Sep 28, 2024
1 parent 25c97f8 commit b18ede6
Show file tree
Hide file tree
Showing 5 changed files with 732 additions and 96 deletions.
132 changes: 36 additions & 96 deletions marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"encoding/binary"
"errors"
"fmt"
"github.com/gocql/gocql/marshal/tinyint"
"math"
"math/big"
"math/bits"
Expand All @@ -30,14 +31,33 @@ var (
ErrorUDTUnavailable = errors.New("UDT are not available on protocols less than 3, please update config")
)

// Marshaler is the interface implemented by objects that can marshal
// themselves into values understood by Cassandra.
// Marshaler is an interface for custom unmarshaler.
// Each value of the 'CQL binary protocol' consist of <value_len> and <value_data>.
// <value_len> can be 'unset'(-2), 'nil'(-1), 'zero'(0) or any value up to 2147483647.
// When <value_len> is 'unset', 'nil' or 'zero', <value_data> is not present.
// 'unset' is applicable only to columns, with some exceptions.
// As you can see from API MarshalCQL only returns <value_data>, but there is a way for it to control <value_len>:
// 1. If MarshalCQL returns (gocql.UnsetValue, nil), gocql writes 'unset' to <value_len>
// 2. If MarshalCQL returns ([]byte(nil), nil), gocql writes 'nil' to <value_len>
// 3. If MarshalCQL returns ([]byte{}, nil), gocql writes 'zero' to <value_len>
//
// Some CQL databases have proprietary value coding features, which you may want to consider.
// CQL binary protocol info:https://github.com/apache/cassandra/tree/trunk/doc
type Marshaler interface {
MarshalCQL(info TypeInfo) ([]byte, error)
}

// Unmarshaler is the interface implemented by objects that can unmarshal
// a Cassandra specific description of themselves.
// Unmarshaler is an interface for custom unmarshaler.
// Each value of the 'CQL binary protocol' consist of <value_len> and <value_data>.
// <value_len> can be 'unset'(-2), 'nil'(-1), 'zero'(0) or any value up to 2147483647.
// When <value_len> is 'unset', 'nil' or 'zero', <value_data> is not present.
// As you can see from an API UnmarshalCQL receives only 'info TypeInfo' and
// 'data []byte', but gocql has the following way to signal about <value_len>:
// 1. When <value_len> is 'nil' gocql feeds nil to 'data []byte'
// 2. When <value_len> is 'zero' gocql feeds []byte{} to 'data []byte'
//
// Some CQL databases have proprietary value coding features, which you may want to consider.
// CQL binary protocol info:https://github.com/apache/cassandra/tree/trunk/doc
type Unmarshaler interface {
UnmarshalCQL(info TypeInfo, data []byte) error
}
Expand Down Expand Up @@ -115,7 +135,7 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
case TypeBoolean:
return marshalBool(info, value)
case TypeTinyInt:
return marshalTinyInt(info, value)
return marshalTinyInt(value)
case TypeSmallInt:
return marshalSmallInt(info, value)
case TypeInt:
Expand Down Expand Up @@ -225,7 +245,7 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error {
case TypeSmallInt:
return unmarshalSmallInt(info, data, value)
case TypeTinyInt:
return unmarshalTinyInt(info, data, value)
return unmarshalTinyInt(data, value)
case TypeFloat:
return unmarshalFloat(info, data, value)
case TypeDouble:
Expand Down Expand Up @@ -438,88 +458,12 @@ func marshalSmallInt(info TypeInfo, value interface{}) ([]byte, error) {
return nil, marshalErrorf("can not marshal %T into %s", value, info)
}

func marshalTinyInt(info TypeInfo, value interface{}) ([]byte, error) {
switch v := value.(type) {
case Marshaler:
return v.MarshalCQL(info)
case unsetColumn:
return nil, nil
case int8:
return []byte{byte(v)}, nil
case uint8:
return []byte{byte(v)}, nil
case int16:
if v > math.MaxInt8 || v < math.MinInt8 {
return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
}
return []byte{byte(v)}, nil
case uint16:
if v > math.MaxUint8 {
return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
}
return []byte{byte(v)}, nil
case int:
if v > math.MaxInt8 || v < math.MinInt8 {
return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
}
return []byte{byte(v)}, nil
case int32:
if v > math.MaxInt8 || v < math.MinInt8 {
return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
}
return []byte{byte(v)}, nil
case int64:
if v > math.MaxInt8 || v < math.MinInt8 {
return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
}
return []byte{byte(v)}, nil
case uint:
if v > math.MaxUint8 {
return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
}
return []byte{byte(v)}, nil
case uint32:
if v > math.MaxUint8 {
return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
}
return []byte{byte(v)}, nil
case uint64:
if v > math.MaxUint8 {
return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
}
return []byte{byte(v)}, nil
case string:
n, err := strconv.ParseInt(v, 10, 8)
if err != nil {
return nil, marshalErrorf("can not marshal %T into %s: %v", value, info, err)
}
return []byte{byte(n)}, nil
}

if value == nil {
return nil, nil
}

switch rv := reflect.ValueOf(value); rv.Type().Kind() {
case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8:
v := rv.Int()
if v > math.MaxInt8 || v < math.MinInt8 {
return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
}
return []byte{byte(v)}, nil
case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8:
v := rv.Uint()
if v > math.MaxUint8 {
return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
}
return []byte{byte(v)}, nil
case reflect.Ptr:
if rv.IsNil() {
return nil, nil
}
func marshalTinyInt(value interface{}) ([]byte, error) {
data, err := tinyint.Marshal(value)
if err != nil {
return nil, MarshalError(err.(Error).Error())
}

return nil, marshalErrorf("can not marshal %T into %s", value, info)
return data, nil
}

func marshalInt(info TypeInfo, value interface{}) ([]byte, error) {
Expand Down Expand Up @@ -619,13 +563,6 @@ func decShort(p []byte) int16 {
return int16(p[0])<<8 | int16(p[1])
}

func decTiny(p []byte) int8 {
if len(p) != 1 {
return 0
}
return int8(p[0])
}

func marshalBigInt(info TypeInfo, value interface{}) ([]byte, error) {
switch v := value.(type) {
case Marshaler:
Expand Down Expand Up @@ -715,8 +652,11 @@ func unmarshalSmallInt(info TypeInfo, data []byte, value interface{}) error {
return unmarshalIntlike(info, int64(decShort(data)), data, value)
}

func unmarshalTinyInt(info TypeInfo, data []byte, value interface{}) error {
return unmarshalIntlike(info, int64(decTiny(data)), data, value)
func unmarshalTinyInt(data []byte, value interface{}) error {
if err := tinyint.Unmarshal(data, value); err != nil {
return UnmarshalError(err.(Error).Error())
}
return nil
}

func unmarshalVarint(info TypeInfo, data []byte, value interface{}) error {
Expand Down
74 changes: 74 additions & 0 deletions marshal/tinyint/marshal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package tinyint

import (
"math/big"
"reflect"
)

func Marshal(value interface{}) ([]byte, error) {
switch v := value.(type) {
case nil:
return nil, nil
case int8:
return EncInt8(v)
case int32:
return EncInt32(v)
case int16:
return EncInt16(v)
case int64:
return EncInt64(v)
case int:
return EncInt(v)

case uint8:
return EncUint8(v)
case uint16:
return EncUint16(v)
case uint32:
return EncUint32(v)
case uint64:
return EncUint64(v)
case uint:
return EncUint(v)

case big.Int:
return EncBigInt(v)
case string:
return EncString(v)

case *int8:
return EncInt8R(v)
case *int16:
return EncInt16R(v)
case *int32:
return EncInt32R(v)
case *int64:
return EncInt64R(v)
case *int:
return EncIntR(v)

case *uint8:
return EncUint8R(v)
case *uint16:
return EncUint16R(v)
case *uint32:
return EncUint32R(v)
case *uint64:
return EncUint64R(v)
case *uint:
return EncUintR(v)

case *big.Int:
return EncBigIntR(v)
case *string:
return EncStringR(v)
default:
// Custom types (type MyInt int) can be serialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.TypeOf(value)
if rv.Kind() != reflect.Ptr {
return EncReflect(reflect.ValueOf(v))
}
return EncReflectR(reflect.ValueOf(v))
}
}
Loading

0 comments on commit b18ede6

Please sign in to comment.