From c8b2212271fd19748f3108b4eee0cb2cb48359ea Mon Sep 17 00:00:00 2001 From: illia-li Date: Sat, 14 Sep 2024 20:54:53 -0400 Subject: [PATCH 1/2] fix `tinyint` marshal, unmarshal --- marshal.go | 149 +++++-------- marshal/tinyint/marshal.go | 74 +++++++ marshal/tinyint/marshal_utils.go | 201 +++++++++++++++++ marshal/tinyint/unmarshal.go | 81 +++++++ marshal/tinyint/unmarshal_utils.go | 334 +++++++++++++++++++++++++++++ 5 files changed, 743 insertions(+), 96 deletions(-) create mode 100644 marshal/tinyint/marshal.go create mode 100644 marshal/tinyint/marshal_utils.go create mode 100644 marshal/tinyint/unmarshal.go create mode 100644 marshal/tinyint/unmarshal_utils.go diff --git a/marshal.go b/marshal.go index 6ae1546ea..743061acd 100644 --- a/marshal.go +++ b/marshal.go @@ -19,6 +19,8 @@ import ( "time" "gopkg.in/inf.v0" + + "github.com/gocql/gocql/marshal/tinyint" ) var ( @@ -30,14 +32,33 @@ var ( ErrorUDTUnavailable = errors.New("UDT are not available on protocols less than 3, please update config") ) -// Marshaler is the interface implemented by objects that can marshal -// themselves into values understood by Cassandra. +// Marshaler is an interface for custom unmarshaler. +// Each value of the 'CQL binary protocol' consist of and . +// can be 'unset'(-2), 'nil'(-1), 'zero'(0) or any value up to 2147483647. +// When is 'unset', 'nil' or 'zero', is not present. +// 'unset' is applicable only to columns, with some exceptions. +// As you can see from API MarshalCQL only returns , but there is a way for it to control : +// 1. If MarshalCQL returns (gocql.UnsetValue, nil), gocql writes 'unset' to +// 2. If MarshalCQL returns ([]byte(nil), nil), gocql writes 'nil' to +// 3. If MarshalCQL returns ([]byte{}, nil), gocql writes 'zero' to +// +// Some CQL databases have proprietary value coding features, which you may want to consider. +// CQL binary protocol info:https://github.com/apache/cassandra/tree/trunk/doc type Marshaler interface { MarshalCQL(info TypeInfo) ([]byte, error) } -// Unmarshaler is the interface implemented by objects that can unmarshal -// a Cassandra specific description of themselves. +// Unmarshaler is an interface for custom unmarshaler. +// Each value of the 'CQL binary protocol' consist of and . +// can be 'unset'(-2), 'nil'(-1), 'zero'(0) or any value up to 2147483647. +// When is 'unset', 'nil' or 'zero', is not present. +// As you can see from an API UnmarshalCQL receives only 'info TypeInfo' and +// 'data []byte', but gocql has the following way to signal about : +// 1. When is 'nil' gocql feeds nil to 'data []byte' +// 2. When is 'zero' gocql feeds []byte{} to 'data []byte' +// +// Some CQL databases have proprietary value coding features, which you may want to consider. +// CQL binary protocol info:https://github.com/apache/cassandra/tree/trunk/doc type Unmarshaler interface { UnmarshalCQL(info TypeInfo, data []byte) error } @@ -115,7 +136,7 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) { case TypeBoolean: return marshalBool(info, value) case TypeTinyInt: - return marshalTinyInt(info, value) + return marshalTinyInt(value) case TypeSmallInt: return marshalSmallInt(info, value) case TypeInt: @@ -225,7 +246,7 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error { case TypeSmallInt: return unmarshalSmallInt(info, data, value) case TypeTinyInt: - return unmarshalTinyInt(info, data, value) + return unmarshalTinyInt(data, value) case TypeFloat: return unmarshalFloat(info, data, value) case TypeDouble: @@ -438,88 +459,12 @@ func marshalSmallInt(info TypeInfo, value interface{}) ([]byte, error) { return nil, marshalErrorf("can not marshal %T into %s", value, info) } -func marshalTinyInt(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case int8: - return []byte{byte(v)}, nil - case uint8: - return []byte{byte(v)}, nil - case int16: - if v > math.MaxInt8 || v < math.MinInt8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case uint16: - if v > math.MaxUint8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case int: - if v > math.MaxInt8 || v < math.MinInt8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case int32: - if v > math.MaxInt8 || v < math.MinInt8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case int64: - if v > math.MaxInt8 || v < math.MinInt8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case uint: - if v > math.MaxUint8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case uint32: - if v > math.MaxUint8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case uint64: - if v > math.MaxUint8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case string: - n, err := strconv.ParseInt(v, 10, 8) - if err != nil { - return nil, marshalErrorf("can not marshal %T into %s: %v", value, info, err) - } - return []byte{byte(n)}, nil - } - - if value == nil { - return nil, nil - } - - switch rv := reflect.ValueOf(value); rv.Type().Kind() { - case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: - v := rv.Int() - if v > math.MaxInt8 || v < math.MinInt8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: - v := rv.Uint() - if v > math.MaxUint8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case reflect.Ptr: - if rv.IsNil() { - return nil, nil - } +func marshalTinyInt(value interface{}) ([]byte, error) { + data, err := tinyint.Marshal(value) + if err != nil { + return nil, wrapMarshalError(err, "marshal error") } - - return nil, marshalErrorf("can not marshal %T into %s", value, info) + return data, nil } func marshalInt(info TypeInfo, value interface{}) ([]byte, error) { @@ -619,13 +564,6 @@ func decShort(p []byte) int16 { return int16(p[0])<<8 | int16(p[1]) } -func decTiny(p []byte) int8 { - if len(p) != 1 { - return 0 - } - return int8(p[0]) -} - func marshalBigInt(info TypeInfo, value interface{}) ([]byte, error) { switch v := value.(type) { case Marshaler: @@ -715,8 +653,11 @@ func unmarshalSmallInt(info TypeInfo, data []byte, value interface{}) error { return unmarshalIntlike(info, int64(decShort(data)), data, value) } -func unmarshalTinyInt(info TypeInfo, data []byte, value interface{}) error { - return unmarshalIntlike(info, int64(decTiny(data)), data, value) +func unmarshalTinyInt(data []byte, value interface{}) error { + if err := tinyint.Unmarshal(data, value); err != nil { + return wrapUnmarshalError(err, "unmarshal error") + } + return nil } func unmarshalVarint(info TypeInfo, data []byte, value interface{}) error { @@ -2734,6 +2675,14 @@ func marshalErrorf(format string, args ...interface{}) MarshalError { return MarshalError{msg: fmt.Sprintf(format, args...)} } +func wrapMarshalError(err error, msg string) MarshalError { + return MarshalError{msg: msg, cause: err} +} + +func wrapMarshalErrorf(err error, format string, a ...interface{}) MarshalError { + return MarshalError{msg: fmt.Sprintf(format, a...), cause: err} +} + type UnmarshalError struct { cause error msg string @@ -2755,3 +2704,11 @@ func (m UnmarshalError) Unwrap() error { func unmarshalErrorf(format string, args ...interface{}) UnmarshalError { return UnmarshalError{msg: fmt.Sprintf(format, args...)} } + +func wrapUnmarshalError(err error, msg string) UnmarshalError { + return UnmarshalError{msg: msg, cause: err} +} + +func wrapUnmarshalErrorf(err error, format string, a ...interface{}) UnmarshalError { + return UnmarshalError{msg: fmt.Sprintf(format, a...), cause: err} +} diff --git a/marshal/tinyint/marshal.go b/marshal/tinyint/marshal.go new file mode 100644 index 000000000..1cfb85c7d --- /dev/null +++ b/marshal/tinyint/marshal.go @@ -0,0 +1,74 @@ +package tinyint + +import ( + "math/big" + "reflect" +) + +func Marshal(value interface{}) ([]byte, error) { + switch v := value.(type) { + case nil: + return nil, nil + case int8: + return EncInt8(v) + case int32: + return EncInt32(v) + case int16: + return EncInt16(v) + case int64: + return EncInt64(v) + case int: + return EncInt(v) + + case uint8: + return EncUint8(v) + case uint16: + return EncUint16(v) + case uint32: + return EncUint32(v) + case uint64: + return EncUint64(v) + case uint: + return EncUint(v) + + case big.Int: + return EncBigInt(v) + case string: + return EncString(v) + + case *int8: + return EncInt8R(v) + case *int16: + return EncInt16R(v) + case *int32: + return EncInt32R(v) + case *int64: + return EncInt64R(v) + case *int: + return EncIntR(v) + + case *uint8: + return EncUint8R(v) + case *uint16: + return EncUint16R(v) + case *uint32: + return EncUint32R(v) + case *uint64: + return EncUint64R(v) + case *uint: + return EncUintR(v) + + case *big.Int: + return EncBigIntR(v) + case *string: + return EncStringR(v) + default: + // Custom types (type MyInt int) can be serialized only via `reflect` package. + // Later, when generic-based serialization is introduced we can do that via generics. + rv := reflect.TypeOf(value) + if rv.Kind() != reflect.Ptr { + return EncReflect(reflect.ValueOf(v)) + } + return EncReflectR(reflect.ValueOf(v)) + } +} diff --git a/marshal/tinyint/marshal_utils.go b/marshal/tinyint/marshal_utils.go new file mode 100644 index 000000000..e204281f4 --- /dev/null +++ b/marshal/tinyint/marshal_utils.go @@ -0,0 +1,201 @@ +package tinyint + +import ( + "fmt" + "math" + "math/big" + "reflect" + "strconv" +) + +var ( + maxBigInt = big.NewInt(math.MaxInt8) + minBigInt = big.NewInt(math.MinInt8) +) + +func EncInt8(v int8) ([]byte, error) { + return []byte{byte(v)}, nil +} + +func EncInt8R(v *int8) ([]byte, error) { + if v == nil { + return nil, nil + } + return EncInt8(*v) +} + +func EncInt16(v int16) ([]byte, error) { + if v > math.MaxInt8 || v < math.MinInt8 { + return nil, fmt.Errorf("failed to marshal tinyint: value %#v out of range", v) + } + return []byte{byte(v)}, nil +} + +func EncInt16R(v *int16) ([]byte, error) { + if v == nil { + return nil, nil + } + return EncInt16(*v) +} + +func EncInt32(v int32) ([]byte, error) { + if v > math.MaxInt8 || v < math.MinInt8 { + return nil, fmt.Errorf("failed to marshal tinyint: value %#v out of range", v) + } + return []byte{byte(v)}, nil +} + +func EncInt32R(v *int32) ([]byte, error) { + if v == nil { + return nil, nil + } + return EncInt32(*v) +} + +func EncInt64(v int64) ([]byte, error) { + if v > math.MaxInt8 || v < math.MinInt8 { + return nil, fmt.Errorf("failed to marshal tinyint: value %#v out of range", v) + } + return []byte{byte(v)}, nil +} + +func EncInt64R(v *int64) ([]byte, error) { + if v == nil { + return nil, nil + } + return EncInt64(*v) +} + +func EncInt(v int) ([]byte, error) { + if v > math.MaxInt8 || v < math.MinInt8 { + return nil, fmt.Errorf("failed to marshal tinyint: value %#v out of range", v) + } + return []byte{byte(v)}, nil +} + +func EncIntR(v *int) ([]byte, error) { + if v == nil { + return nil, nil + } + return EncInt(*v) +} + +func EncUint8(v uint8) ([]byte, error) { + return []byte{v}, nil +} + +func EncUint8R(v *uint8) ([]byte, error) { + if v == nil { + return nil, nil + } + return EncUint8(*v) +} + +func EncUint16(v uint16) ([]byte, error) { + if v > math.MaxUint8 { + return nil, fmt.Errorf("failed to marshal tinyint: value %#v out of range", v) + } + return []byte{byte(v)}, nil +} + +func EncUint16R(v *uint16) ([]byte, error) { + if v == nil { + return nil, nil + } + return EncUint16(*v) +} + +func EncUint32(v uint32) ([]byte, error) { + if v > math.MaxUint8 { + return nil, fmt.Errorf("failed to marshal tinyint: value %#v out of range", v) + } + return []byte{byte(v)}, nil +} + +func EncUint32R(v *uint32) ([]byte, error) { + if v == nil { + return nil, nil + } + return EncUint32(*v) +} + +func EncUint64(v uint64) ([]byte, error) { + if v > math.MaxUint8 { + return nil, fmt.Errorf("failed to marshal tinyint: value %#v out of range", v) + } + return []byte{byte(v)}, nil +} + +func EncUint64R(v *uint64) ([]byte, error) { + if v == nil { + return nil, nil + } + return EncUint64(*v) +} + +func EncUint(v uint) ([]byte, error) { + if v > math.MaxUint8 { + return nil, fmt.Errorf("failed to marshal tinyint: value %#v out of range", v) + } + return []byte{byte(v)}, nil +} + +func EncUintR(v *uint) ([]byte, error) { + if v == nil { + return nil, nil + } + return EncUint(*v) +} + +func EncBigInt(v big.Int) ([]byte, error) { + if v.Cmp(maxBigInt) == 1 || v.Cmp(minBigInt) == -1 { + return nil, fmt.Errorf("failed to marshal tinyint: value (%T)(%s) out of range", v, v.String()) + } + return []byte{byte(v.Int64())}, nil +} + +func EncBigIntR(v *big.Int) ([]byte, error) { + if v == nil { + return nil, nil + } + return EncBigInt(*v) +} + +func EncString(v string) ([]byte, error) { + if v == "" { + return nil, nil + } + + n, err := strconv.ParseInt(v, 10, 8) + if err != nil { + return nil, fmt.Errorf("failed to marshal tinyint: can not marshal %#v %s", v, err) + } + return []byte{byte(n)}, nil +} + +func EncStringR(v *string) ([]byte, error) { + if v == nil { + return nil, nil + } + return EncString(*v) +} + +func EncReflect(v reflect.Value) ([]byte, error) { + switch v.Type().Kind() { + case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: + return EncInt64(v.Int()) + case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: + return EncUint64(v.Uint()) + case reflect.String: + return EncString(v.String()) + default: + return nil, fmt.Errorf("failed to marshal tinyint: unsupported value type (%T)(%#[1]v)", v.Interface()) + } +} + +func EncReflectR(v reflect.Value) ([]byte, error) { + if v.IsNil() { + return nil, nil + } + return EncReflect(v.Elem()) +} diff --git a/marshal/tinyint/unmarshal.go b/marshal/tinyint/unmarshal.go new file mode 100644 index 000000000..4a983e7c1 --- /dev/null +++ b/marshal/tinyint/unmarshal.go @@ -0,0 +1,81 @@ +package tinyint + +import ( + "fmt" + "math/big" + "reflect" +) + +func Unmarshal(data []byte, value interface{}) error { + switch v := value.(type) { + case nil: + return nil + + case *int8: + return DecInt8(data, v) + case *int16: + return DecInt16(data, v) + case *int32: + return DecInt32(data, v) + case *int64: + return DecInt64(data, v) + case *int: + return DecInt(data, v) + + case *uint8: + return DecUint8(data, v) + case *uint16: + return DecUint16(data, v) + case *uint32: + return DecUint32(data, v) + case *uint64: + return DecUint64(data, v) + case *uint: + return DecUint(data, v) + + case *big.Int: + return DecBigInt(data, v) + case *string: + return DecString(data, v) + + case **int8: + return DecInt8R(data, v) + case **int16: + return DecInt16R(data, v) + case **int32: + return DecInt32R(data, v) + case **int64: + return DecInt64R(data, v) + case **int: + return DecIntR(data, v) + + case **uint8: + return DecUint8R(data, v) + case **uint16: + return DecUint16R(data, v) + case **uint32: + return DecUint32R(data, v) + case **uint64: + return DecUint64R(data, v) + case **uint: + return DecUintR(data, v) + + case **big.Int: + return DecBigIntR(data, v) + case **string: + return DecStringR(data, v) + default: + + // Custom types (type MyInt int) can be deserialized only via `reflect` package. + // Later, when generic-based serialization is introduced we can do that via generics. + rv := reflect.ValueOf(value) + rt := rv.Type() + if rt.Kind() != reflect.Ptr { + return fmt.Errorf("failed to unmarshal tinyint: unsupported value type (%T)(%#[1]v)", value) + } + if rt.Elem().Kind() != reflect.Ptr { + return DecReflect(data, rv) + } + return DecReflectR(data, rv) + } +} diff --git a/marshal/tinyint/unmarshal_utils.go b/marshal/tinyint/unmarshal_utils.go new file mode 100644 index 000000000..55b301702 --- /dev/null +++ b/marshal/tinyint/unmarshal_utils.go @@ -0,0 +1,334 @@ +package tinyint + +import ( + "fmt" + "math/big" + "reflect" + "strconv" +) + +var errWrongDataLen = fmt.Errorf("failed to unmarshal tinyint: the length of the data should less or equal then 1") + +func DecInt8(p []byte, v *int8) error { + switch len(p) { + case 0: + *v = 0 + case 1: + *v = int8(p[0]) + default: + return errWrongDataLen + } + return nil +} + +func DecInt8R(p []byte, v **int8) error { + if p != nil { + *v = new(int8) + return DecInt8(p, *v) + } + *v = nil + return nil +} + +func DecInt16(p []byte, v *int16) error { + switch len(p) { + case 0: + *v = 0 + case 1: + *v = int16(int8(p[0])) + default: + return errWrongDataLen + } + return nil +} + +func DecInt16R(p []byte, v **int16) error { + if p != nil { + *v = new(int16) + return DecInt16(p, *v) + } + *v = nil + return nil +} + +func DecInt32(p []byte, v *int32) error { + switch len(p) { + case 0: + *v = 0 + case 1: + *v = int32(int8(p[0])) + default: + return errWrongDataLen + } + return nil +} + +func DecInt32R(p []byte, v **int32) error { + if p != nil { + *v = new(int32) + return DecInt32(p, *v) + } + *v = nil + return nil +} + +func DecInt64(p []byte, v *int64) error { + switch len(p) { + case 0: + *v = 0 + case 1: + *v = int64(int8(p[0])) + default: + return errWrongDataLen + } + return nil +} + +func DecInt64R(p []byte, v **int64) error { + if p != nil { + *v = new(int64) + return DecInt64(p, *v) + } + *v = nil + return nil +} + +func DecInt(p []byte, v *int) error { + switch len(p) { + case 0: + *v = 0 + case 1: + *v = int(int8(p[0])) + default: + return errWrongDataLen + } + return nil +} + +func DecIntR(p []byte, v **int) error { + if p != nil { + *v = new(int) + return DecInt(p, *v) + } + *v = nil + return nil +} + +func DecUint8(p []byte, v *uint8) error { + switch len(p) { + case 0: + *v = 0 + case 1: + *v = p[0] + default: + return errWrongDataLen + } + return nil +} + +func DecUint8R(p []byte, v **uint8) error { + if p != nil { + *v = new(uint8) + return DecUint8(p, *v) + } + *v = nil + return nil +} + +func DecUint16(p []byte, v *uint16) error { + switch len(p) { + case 0: + *v = 0 + case 1: + *v = uint16(p[0]) + default: + return errWrongDataLen + } + return nil +} + +func DecUint16R(p []byte, v **uint16) error { + if p != nil { + *v = new(uint16) + return DecUint16(p, *v) + } + *v = nil + return nil +} + +func DecUint32(p []byte, v *uint32) error { + switch len(p) { + case 0: + *v = 0 + case 1: + *v = uint32(p[0]) + default: + return errWrongDataLen + } + return nil +} + +func DecUint32R(p []byte, v **uint32) error { + if p != nil { + *v = new(uint32) + return DecUint32(p, *v) + } + *v = nil + return nil +} + +func DecUint64(p []byte, v *uint64) error { + switch len(p) { + case 0: + *v = 0 + case 1: + *v = uint64(p[0]) + default: + return errWrongDataLen + } + return nil +} + +func DecUint64R(p []byte, v **uint64) error { + if p != nil { + *v = new(uint64) + return DecUint64(p, *v) + } + *v = nil + return nil +} + +func DecUint(p []byte, v *uint) error { + switch len(p) { + case 0: + *v = 0 + case 1: + *v = uint(p[0]) + default: + return errWrongDataLen + } + return nil +} + +func DecUintR(p []byte, v **uint) error { + if p != nil { + *v = new(uint) + return DecUint(p, *v) + } + *v = nil + return nil +} + +func DecString(p []byte, v *string) error { + switch len(p) { + case 0: + if p != nil { + *v = "0" + } else { + *v = "" + } + case 1: + *v = strconv.FormatInt(int64(int8(p[0])), 10) + default: + return errWrongDataLen + } + return nil +} + +func DecStringR(p []byte, v **string) error { + if p != nil { + *v = new(string) + return DecString(p, *v) + } + *v = nil + return nil +} + +func DecBigInt(p []byte, v *big.Int) error { + switch len(p) { + case 0: + v.SetInt64(0) + case 1: + v.SetInt64(int64(int8(p[0]))) + default: + return errWrongDataLen + } + return nil +} + +func DecBigIntR(p []byte, v **big.Int) error { + if p != nil { + *v = big.NewInt(0) + return DecBigInt(p, *v) + } + *v = nil + return nil +} + +func DecReflect(p []byte, v reflect.Value) error { + if v.IsNil() { + return fmt.Errorf("failed to unmarshal tinyint: can not unmarshal into nil reference (%T)(%#[1]v)", v.Interface()) + } + + switch v = v.Elem(); v.Kind() { + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int: + return decReflectInts(p, v) + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + return decReflectUints(p, v) + case reflect.String: + return decReflectString(p, v) + default: + return fmt.Errorf("failed to unmarshal tinyint: unsupported value type (%T)(%#[1]v)", v.Interface()) + } +} + +func DecReflectR(p []byte, v reflect.Value) error { + if p != nil { + zeroValue := reflect.New(v.Type().Elem().Elem()) + v.Elem().Set(zeroValue) + return DecReflect(p, v.Elem()) + } + nilValue := reflect.Zero(v.Elem().Type()) + v.Elem().Set(nilValue) + return nil +} + +func decReflectInts(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + v.SetInt(0) + case 1: + v.SetInt(int64(int8(p[0]))) + default: + return errWrongDataLen + } + return nil +} + +func decReflectUints(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + v.SetUint(0) + case 1: + v.SetUint(uint64(p[0])) + default: + return errWrongDataLen + } + return nil +} + +func decReflectString(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + if p != nil { + v.SetString("0") + } else { + v.SetString("") + } + case 1: + v.SetString(strconv.FormatInt(int64(int8(p[0])), 10)) + default: + return errWrongDataLen + } + return nil +} From 89a03acc80ea479dc20072c34c6ed04b171076ef Mon Sep 17 00:00:00 2001 From: illia-li Date: Tue, 1 Oct 2024 08:35:55 -0400 Subject: [PATCH 2/2] add tests --- marshal_2_tinyint_corrupt_test.go | 69 +++++++++++++++ marshal_2_tinyint_test.go | 139 ++++++++++++++++-------------- 2 files changed, 141 insertions(+), 67 deletions(-) create mode 100644 marshal_2_tinyint_corrupt_test.go diff --git a/marshal_2_tinyint_corrupt_test.go b/marshal_2_tinyint_corrupt_test.go new file mode 100644 index 000000000..680f44a61 --- /dev/null +++ b/marshal_2_tinyint_corrupt_test.go @@ -0,0 +1,69 @@ +package gocql_test + +import ( + "math/big" + "testing" + + "github.com/gocql/gocql" + "github.com/gocql/gocql/internal/tests/serialization" + "github.com/gocql/gocql/internal/tests/serialization/mod" + "github.com/gocql/gocql/marshal/tinyint" +) + +func TestMarshalTinyintCorrupt(t *testing.T) { + type testSuite struct { + name string + marshal func(interface{}) ([]byte, error) + unmarshal func(bytes []byte, i interface{}) error + } + + tType := gocql.NewNativeType(4, gocql.TypeTinyInt, "") + + testSuites := [2]testSuite{ + { + name: "serialization.tinyint", + marshal: tinyint.Marshal, + unmarshal: tinyint.Unmarshal, + }, + { + name: "glob", + marshal: func(i interface{}) ([]byte, error) { + return gocql.Marshal(tType, i) + }, + unmarshal: func(bytes []byte, i interface{}) error { + return gocql.Unmarshal(tType, bytes, i) + }, + }, + } + + for _, tSuite := range testSuites { + marshal := tSuite.marshal + unmarshal := tSuite.unmarshal + + t.Run(tSuite.name, func(t *testing.T) { + + serialization.NegativeMarshalSet{ + Values: mod.Values{ + int16(128), int32(128), int64(128), int(128), + "128", *big.NewInt(128), + int16(-129), int32(-129), int64(-129), int(-129), + "-129", *big.NewInt(-129), + uint16(256), uint32(256), uint64(256), uint(256), + }.AddVariants(mod.All...), + }.Run("big_vals", t, marshal) + + serialization.NegativeMarshalSet{ + Values: mod.Values{"1s2", "1s", "-1s", ".1", ",1", "0.1", "0,1"}.AddVariants(mod.All...), + }.Run("corrupt_vals", t, marshal) + + serialization.NegativeUnmarshalSet{ + Data: []byte("\x80\x00"), + Values: mod.Values{ + int8(0), int16(0), int32(0), int64(0), int(0), + uint8(0), uint16(0), uint32(0), uint64(0), uint(0), + "", *big.NewInt(0), + }.AddVariants(mod.All...), + }.Run("big_data", t, unmarshal) + }) + } +} diff --git a/marshal_2_tinyint_test.go b/marshal_2_tinyint_test.go index 80ae976c1..0ce760d2d 100644 --- a/marshal_2_tinyint_test.go +++ b/marshal_2_tinyint_test.go @@ -7,85 +7,90 @@ import ( "github.com/gocql/gocql" "github.com/gocql/gocql/internal/tests/serialization" "github.com/gocql/gocql/internal/tests/serialization/mod" + "github.com/gocql/gocql/marshal/tinyint" ) func TestMarshalTinyint(t *testing.T) { - tType := gocql.NewNativeType(4, gocql.TypeTinyInt, "") - - marshal := func(i interface{}) ([]byte, error) { return gocql.Marshal(tType, i) } - unmarshal := func(bytes []byte, i interface{}) error { - return gocql.Unmarshal(tType, bytes, i) + type testSuite struct { + name string + marshal func(interface{}) ([]byte, error) + unmarshal func(bytes []byte, i interface{}) error } - // unmarshal `custom string` unsupported - brokenCustomStrings := serialization.GetTypes(mod.String(""), (*mod.String)(nil)) + tType := gocql.NewNativeType(4, gocql.TypeTinyInt, "") - // marshal "" (empty string) unsupported - // unmarshal nil value into (string)("0") - brokenEmptyStrings := serialization.GetTypes(string(""), mod.String("")) + testSuites := [2]testSuite{ + { + name: "serialization.tinyint", + marshal: tinyint.Marshal, + unmarshal: tinyint.Unmarshal, + }, + { + name: "glob", + marshal: func(i interface{}) ([]byte, error) { + return gocql.Marshal(tType, i) + }, + unmarshal: func(bytes []byte, i interface{}) error { + return gocql.Unmarshal(tType, bytes, i) + }, + }, + } - // marshal `custom string` unsupported - // marshal `big.Int` unsupported - brokenMarshalTypes := append(brokenCustomStrings, serialization.GetTypes(big.Int{}, &big.Int{})...) + for _, tSuite := range testSuites { + marshal := tSuite.marshal + unmarshal := tSuite.unmarshal - serialization.PositiveSet{ - Data: nil, - Values: mod.Values{ - (*int8)(nil), (*int16)(nil), (*int32)(nil), (*int64)(nil), (*int)(nil), - (*uint8)(nil), (*uint16)(nil), (*uint32)(nil), (*uint64)(nil), (*uint)(nil), - (*string)(nil), (*big.Int)(nil), string(""), - }.AddVariants(mod.CustomType), - BrokenMarshalTypes: brokenEmptyStrings, - BrokenUnmarshalTypes: brokenEmptyStrings, - }.Run("[nil]nullable", t, marshal, unmarshal) + t.Run(tSuite.name, func(t *testing.T) { + serialization.PositiveSet{ + Data: nil, + Values: mod.Values{ + (*int8)(nil), (*int16)(nil), (*int32)(nil), (*int64)(nil), (*int)(nil), + (*uint8)(nil), (*uint16)(nil), (*uint32)(nil), (*uint64)(nil), (*uint)(nil), + (*string)(nil), (*big.Int)(nil), string(""), + }.AddVariants(mod.CustomType), + }.Run("[nil]nullable", t, marshal, unmarshal) - serialization.PositiveSet{ - Data: nil, - Values: mod.Values{ - int8(0), int16(0), int32(0), int64(0), int(0), - uint8(0), uint16(0), uint32(0), uint64(0), uint(0), - "0", big.Int{}, - }.AddVariants(mod.CustomType), - BrokenUnmarshalTypes: brokenCustomStrings, - }.Run("[nil]unmarshal", t, nil, unmarshal) + serialization.PositiveSet{ + Data: nil, + Values: mod.Values{ + int8(0), int16(0), int32(0), int64(0), int(0), + uint8(0), uint16(0), uint32(0), uint64(0), uint(0), + "", big.Int{}, + }.AddVariants(mod.CustomType), + }.Run("[nil]unmarshal", t, nil, unmarshal) - serialization.PositiveSet{ - Data: make([]byte, 0), - Values: mod.Values{ - int8(0), int16(0), int32(0), int64(0), int(0), - uint8(0), uint16(0), uint32(0), uint64(0), uint(0), - "0", *big.NewInt(0), - }.AddVariants(mod.All...), - BrokenUnmarshalTypes: brokenCustomStrings, - }.Run("[]unmarshal", t, nil, unmarshal) + serialization.PositiveSet{ + Data: make([]byte, 0), + Values: mod.Values{ + int8(0), int16(0), int32(0), int64(0), int(0), + uint8(0), uint16(0), uint32(0), uint64(0), uint(0), + "0", *big.NewInt(0), + }.AddVariants(mod.All...), + }.Run("[]unmarshal", t, nil, unmarshal) - serialization.PositiveSet{ - Data: []byte("\x00"), - Values: mod.Values{ - int8(0), int16(0), int32(0), int64(0), int(0), - uint8(0), uint16(0), uint32(0), uint64(0), uint(0), - "0", *big.NewInt(0), - }.AddVariants(mod.All...), - BrokenMarshalTypes: brokenMarshalTypes, - BrokenUnmarshalTypes: brokenCustomStrings, - }.Run("zeros", t, marshal, unmarshal) + serialization.PositiveSet{ + Data: []byte("\x00"), + Values: mod.Values{ + int8(0), int16(0), int32(0), int64(0), int(0), + uint8(0), uint16(0), uint32(0), uint64(0), uint(0), + "0", *big.NewInt(0), + }.AddVariants(mod.All...), + }.Run("zeros", t, marshal, unmarshal) - serialization.PositiveSet{ - Data: []byte("\x7f"), - Values: mod.Values{int8(127), int16(127), int32(127), int64(127), int(127), "127", *big.NewInt(127)}.AddVariants(mod.All...), - BrokenMarshalTypes: brokenMarshalTypes, - BrokenUnmarshalTypes: brokenCustomStrings, - }.Run("127", t, marshal, unmarshal) + serialization.PositiveSet{ + Data: []byte("\x7f"), + Values: mod.Values{int8(127), int16(127), int32(127), int64(127), int(127), "127", *big.NewInt(127)}.AddVariants(mod.All...), + }.Run("127", t, marshal, unmarshal) - serialization.PositiveSet{ - Data: []byte("\x80"), - Values: mod.Values{int8(-128), int16(-128), int32(-128), int64(-128), int(-128), "-128", *big.NewInt(-128)}.AddVariants(mod.All...), - BrokenMarshalTypes: brokenMarshalTypes, - BrokenUnmarshalTypes: brokenCustomStrings, - }.Run("-128", t, marshal, unmarshal) + serialization.PositiveSet{ + Data: []byte("\x80"), + Values: mod.Values{int8(-128), int16(-128), int32(-128), int64(-128), int(-128), "-128", *big.NewInt(-128)}.AddVariants(mod.All...), + }.Run("-128", t, marshal, unmarshal) - serialization.PositiveSet{ - Data: []byte("\xff"), - Values: mod.Values{uint8(255), uint16(255), uint32(255), uint64(255), uint(255)}.AddVariants(mod.All...), - }.Run("255", t, marshal, unmarshal) + serialization.PositiveSet{ + Data: []byte("\xff"), + Values: mod.Values{uint8(255), uint16(255), uint32(255), uint64(255), uint(255)}.AddVariants(mod.All...), + }.Run("255", t, marshal, unmarshal) + }) + } }