diff --git a/serialization/smallint/marshal_utils.go b/serialization/smallint/marshal_utils.go index baec06bb4..6994168f6 100644 --- a/serialization/smallint/marshal_utils.go +++ b/serialization/smallint/marshal_utils.go @@ -14,7 +14,10 @@ var ( ) func EncInt8(v int8) ([]byte, error) { - return encInt16(int16(v)), nil + if v < 0 { + return []byte{255, byte(v)}, nil + } + return []byte{0, byte(v)}, nil } func EncInt8R(v *int8) ([]byte, error) { @@ -53,7 +56,7 @@ func EncInt64(v int64) ([]byte, error) { if v > math.MaxInt16 || v < math.MinInt16 { return nil, fmt.Errorf("failed to marshal smallint: value %#v out of range", v) } - return []byte{byte(v >> 8), byte(v)}, nil + return encInt64(v), nil } func EncInt64R(v *int64) ([]byte, error) { @@ -117,7 +120,7 @@ func EncUint64(v uint64) ([]byte, error) { if v > math.MaxUint16 { return nil, fmt.Errorf("failed to marshal smallint: value %#v out of range", v) } - return []byte{byte(v >> 8), byte(v)}, nil + return encUint64(v), nil } func EncUint64R(v *uint64) ([]byte, error) { @@ -145,15 +148,17 @@ func EncBigInt(v big.Int) ([]byte, error) { if v.Cmp(maxBigInt) == 1 || v.Cmp(minBigInt) == -1 { return nil, fmt.Errorf("failed to marshal smallint: value (%T)(%s) out of range", v, v.String()) } - iv := v.Int64() - return []byte{byte(iv >> 8), byte(iv)}, nil + return encInt64(v.Int64()), nil } func EncBigIntR(v *big.Int) ([]byte, error) { if v == nil { return nil, nil } - return EncBigInt(*v) + if v.Cmp(maxBigInt) == 1 || v.Cmp(minBigInt) == -1 { + return nil, fmt.Errorf("failed to marshal smallint: value (%T)(%s) out of range", v, v.String()) + } + return encInt64(v.Int64()), nil } func EncString(v string) ([]byte, error) { @@ -165,7 +170,7 @@ func EncString(v string) ([]byte, error) { if err != nil { return nil, fmt.Errorf("failed to marshal smallint: can not marshal (%T)(%[1]v) %s", v, err) } - return []byte{byte(n >> 8), byte(n)}, nil + return encInt64(n), nil } func EncStringR(v *string) ([]byte, error) { @@ -177,12 +182,41 @@ func EncStringR(v *string) ([]byte, error) { func EncReflect(v reflect.Value) ([]byte, error) { switch v.Type().Kind() { - case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: - return EncInt64(v.Int()) - case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: - return EncUint64(v.Uint()) + case reflect.Int8: + val := v.Int() + if val < 0 { + return []byte{255, byte(val)}, nil + } + return []byte{0, byte(val)}, nil + case reflect.Int16: + return encInt64(v.Int()), nil + case reflect.Int, reflect.Int64, reflect.Int32: + val := v.Int() + if val > math.MaxInt16 || val < math.MinInt16 { + return nil, fmt.Errorf("failed to marshal smallint: custom type value (%T)(%[1]v) out of range", v.Interface()) + } + return encInt64(val), nil + case reflect.Uint8: + return []byte{0, byte(v.Uint())}, nil + case reflect.Uint16: + return encUint64(v.Uint()), nil + case reflect.Uint, reflect.Uint64, reflect.Uint32: + val := v.Uint() + if val > math.MaxUint16 { + return nil, fmt.Errorf("failed to marshal smallint: custom type value (%T)(%[1]v) out of range", v.Interface()) + } + return encUint64(val), nil case reflect.String: - return EncString(v.String()) + val := v.String() + if val == "" { + return nil, nil + } + + n, err := strconv.ParseInt(val, 10, 16) + if err != nil { + return nil, fmt.Errorf("failed to marshal smallint: can not marshal (%T)(%[1]v), %s", v.Interface(), err) + } + return encInt64(n), nil default: return nil, fmt.Errorf("failed to marshal smallint: unsupported value type (%T)(%[1]v)", v.Interface()) } @@ -198,3 +232,11 @@ func EncReflectR(v reflect.Value) ([]byte, error) { func encInt16(v int16) []byte { return []byte{byte(v >> 8), byte(v)} } + +func encInt64(v int64) []byte { + return []byte{byte(v >> 8), byte(v)} +} + +func encUint64(v uint64) []byte { + return []byte{byte(v >> 8), byte(v)} +} diff --git a/serialization/smallint/unmarshal_utils.go b/serialization/smallint/unmarshal_utils.go index bbea2d91d..bac337e98 100644 --- a/serialization/smallint/unmarshal_utils.go +++ b/serialization/smallint/unmarshal_utils.go @@ -37,11 +37,23 @@ func DecInt8R(p []byte, v **int8) error { if v == nil { return errNilReference(v) } - if p != nil { - *v = new(int8) - return DecInt8(p, *v) + switch len(p) { + case 0: + if p == nil { + *v = nil + } else { + *v = new(int8) + } + case 2: + val := decInt16(p) + if val > math.MaxInt8 || val < math.MinInt8 { + return fmt.Errorf("failed to unmarshal smallint: to unmarshal into int8, the data should be in the int8 range") + } + tmp := int8(val) + *v = &tmp + default: + return errWrongDataLen } - *v = nil return nil } @@ -64,11 +76,19 @@ func DecInt16R(p []byte, v **int16) error { if v == nil { return errNilReference(v) } - if p != nil { - *v = new(int16) - return DecInt16(p, *v) + switch len(p) { + case 0: + if p == nil { + *v = nil + } else { + *v = new(int16) + } + case 2: + val := decInt16(p) + *v = &val + default: + return errWrongDataLen } - *v = nil return nil } @@ -80,7 +100,7 @@ func DecInt32(p []byte, v *int32) error { case 0: *v = 0 case 2: - *v = int32(decInt16(p)) + *v = decInt32(p) default: return errWrongDataLen } @@ -91,11 +111,19 @@ func DecInt32R(p []byte, v **int32) error { if v == nil { return errNilReference(v) } - if p != nil { - *v = new(int32) - return DecInt32(p, *v) + switch len(p) { + case 0: + if p == nil { + *v = nil + } else { + *v = new(int32) + } + case 2: + val := decInt32(p) + *v = &val + default: + return errWrongDataLen } - *v = nil return nil } @@ -107,7 +135,7 @@ func DecInt64(p []byte, v *int64) error { case 0: *v = 0 case 2: - *v = int64(decInt16(p)) + *v = decInt64(p) default: return errWrongDataLen } @@ -118,11 +146,19 @@ func DecInt64R(p []byte, v **int64) error { if v == nil { return errNilReference(v) } - if p != nil { - *v = new(int64) - return DecInt64(p, *v) + switch len(p) { + case 0: + if p == nil { + *v = nil + } else { + *v = new(int64) + } + case 2: + val := decInt64(p) + *v = &val + default: + return errWrongDataLen } - *v = nil return nil } @@ -134,7 +170,7 @@ func DecInt(p []byte, v *int) error { case 0: *v = 0 case 2: - *v = int(decInt16(p)) + *v = decInt(p) default: return errWrongDataLen } @@ -145,11 +181,19 @@ func DecIntR(p []byte, v **int) error { if v == nil { return errNilReference(v) } - if p != nil { - *v = new(int) - return DecInt(p, *v) + switch len(p) { + case 0: + if p == nil { + *v = nil + } else { + *v = new(int) + } + case 2: + val := decInt(p) + *v = &val + default: + return errWrongDataLen } - *v = nil return nil } @@ -175,11 +219,22 @@ func DecUint8R(p []byte, v **uint8) error { if v == nil { return errNilReference(v) } - if p != nil { - *v = new(uint8) - return DecUint8(p, *v) + switch len(p) { + case 0: + if p == nil { + *v = nil + } else { + *v = new(uint8) + } + case 2: + if p[0] != 0 { + return fmt.Errorf("failed to unmarshal smallint: to unmarshal into uint8, the data should be in the uint8 range") + } + val := p[1] + *v = &val + default: + return errWrongDataLen } - *v = nil return nil } @@ -191,7 +246,7 @@ func DecUint16(p []byte, v *uint16) error { case 0: *v = 0 case 2: - *v = uint16(p[0])<<8 | uint16(p[1]) + *v = decUint16(p) default: return errWrongDataLen } @@ -202,11 +257,19 @@ func DecUint16R(p []byte, v **uint16) error { if v == nil { return errNilReference(v) } - if p != nil { - *v = new(uint16) - return DecUint16(p, *v) + switch len(p) { + case 0: + if p == nil { + *v = nil + } else { + *v = new(uint16) + } + case 2: + val := decUint16(p) + *v = &val + default: + return errWrongDataLen } - *v = nil return nil } @@ -218,7 +281,7 @@ func DecUint32(p []byte, v *uint32) error { case 0: *v = 0 case 2: - *v = uint32(p[0])<<8 | uint32(p[1]) + *v = decUint32(p) default: return errWrongDataLen } @@ -229,11 +292,19 @@ func DecUint32R(p []byte, v **uint32) error { if v == nil { return errNilReference(v) } - if p != nil { - *v = new(uint32) - return DecUint32(p, *v) + switch len(p) { + case 0: + if p == nil { + *v = nil + } else { + *v = new(uint32) + } + case 2: + val := decUint32(p) + *v = &val + default: + return errWrongDataLen } - *v = nil return nil } @@ -256,11 +327,19 @@ func DecUint64R(p []byte, v **uint64) error { if v == nil { return errNilReference(v) } - if p != nil { - *v = new(uint64) - return DecUint64(p, *v) + switch len(p) { + case 0: + if p == nil { + *v = nil + } else { + *v = new(uint64) + } + case 2: + val := decUint64(p) + *v = &val + default: + return errWrongDataLen } - *v = nil return nil } @@ -272,7 +351,7 @@ func DecUint(p []byte, v *uint) error { case 0: *v = 0 case 2: - *v = uint(p[0])<<8 | uint(p[1]) + *v = decUint(p) default: return errWrongDataLen } @@ -283,11 +362,19 @@ func DecUintR(p []byte, v **uint) error { if v == nil { return errNilReference(v) } - if p != nil { - *v = new(uint) - return DecUint(p, *v) + switch len(p) { + case 0: + if p == nil { + *v = nil + } else { + *v = new(uint) + } + case 2: + val := decUint(p) + *v = &val + default: + return errWrongDataLen } - *v = nil return nil } @@ -297,13 +384,13 @@ func DecString(p []byte, v *string) error { } switch len(p) { case 0: - if p != nil { - *v = "0" - } else { + if p == nil { *v = "" + } else { + *v = "0" } case 2: - *v = strconv.FormatInt(int64(decInt16(p)), 10) + *v = strconv.FormatInt(decInt64(p), 10) default: return errWrongDataLen } @@ -314,11 +401,20 @@ func DecStringR(p []byte, v **string) error { if v == nil { return errNilReference(v) } - if p != nil { - *v = new(string) - return DecString(p, *v) + switch len(p) { + case 0: + if p == nil { + *v = nil + } else { + val := "0" + *v = &val + } + case 2: + val := strconv.FormatInt(decInt64(p), 10) + *v = &val + default: + return errWrongDataLen } - *v = nil return nil } @@ -330,7 +426,7 @@ func DecBigInt(p []byte, v *big.Int) error { case 0: v.SetInt64(0) case 2: - v.SetInt64(int64(decInt16(p))) + v.SetInt64(decInt64(p)) default: return errWrongDataLen } @@ -341,11 +437,18 @@ func DecBigIntR(p []byte, v **big.Int) error { if v == nil { return errNilReference(v) } - if p != nil { - *v = big.NewInt(0) - return DecBigInt(p, *v) + switch len(p) { + case 0: + if p == nil { + *v = nil + } else { + *v = big.NewInt(0) + } + case 2: + *v = big.NewInt(decInt64(p)) + default: + return errWrongDataLen } - *v = nil return nil } @@ -357,10 +460,10 @@ func DecReflect(p []byte, v reflect.Value) error { switch v = v.Elem(); v.Kind() { case reflect.Int8: return decReflectInt8(p, v) - case reflect.Uint8: - return decReflectUint8(p, v) case reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int: return decReflectInts(p, v) + case reflect.Uint8: + return decReflectUint8(p, v) case reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: return decReflectUints(p, v) case reflect.String: @@ -371,14 +474,24 @@ func DecReflect(p []byte, v reflect.Value) error { } func DecReflectR(p []byte, v reflect.Value) error { - if p != nil { - zeroValue := reflect.New(v.Type().Elem().Elem()) - v.Elem().Set(zeroValue) - return DecReflect(p, v.Elem()) + if v.IsNil() { + return fmt.Errorf("failed to unmarshal tinyint: can not unmarshal into nil reference (%T)(%[1]v)", v.Interface()) + } + + switch v.Type().Elem().Elem().Kind() { + case reflect.Int8: + return decReflectInt8R(p, v) + case reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int: + return decReflectIntsR(p, v) + case reflect.Uint8: + return decReflectUint8R(p, v) + case reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + return decReflectUintsR(p, v) + case reflect.String: + return decReflectStringR(p, v) + default: + return fmt.Errorf("failed to unmarshal tinyint: unsupported value type (%T)(%[1]v)", v.Interface()) } - nilValue := reflect.Zero(v.Elem().Type()) - v.Elem().Set(nilValue) - return nil } func decReflectInt8(p []byte, v reflect.Value) error { @@ -386,11 +499,11 @@ func decReflectInt8(p []byte, v reflect.Value) error { case 0: v.SetInt(0) case 2: - val := decInt16(p) + val := decInt64(p) if val > math.MaxInt8 || val < math.MinInt8 { - return fmt.Errorf("failed to unmarshal smallint: to unmarshal into %T, the data should be in the int8 range", v.Interface()) + return fmt.Errorf("failed to unmarshal smallint: to unmarshal into (%T), the data should be in the int8 range", v.Interface()) } - v.SetInt(int64(val)) + v.SetInt(val) default: return errWrongDataLen } @@ -402,7 +515,7 @@ func decReflectInts(p []byte, v reflect.Value) error { case 0: v.SetInt(0) case 2: - v.SetInt(int64(decInt16(p))) + v.SetInt(decInt64(p)) default: return errWrongDataLen } @@ -415,7 +528,7 @@ func decReflectUint8(p []byte, v reflect.Value) error { v.SetUint(0) case 2: if p[0] != 0 { - return fmt.Errorf("failed to unmarshal smallint: to unmarshal into %T, the data should be in the uint8 range", v.Interface()) + return fmt.Errorf("failed to unmarshal smallint: to unmarshal into (%T), the data should be in the uint8 range", v.Interface()) } v.SetUint(uint64(p[1])) default: @@ -445,7 +558,98 @@ func decReflectString(p []byte, v reflect.Value) error { v.SetString("") } case 2: - v.SetString(strconv.FormatInt(int64(decInt16(p)), 10)) + v.SetString(strconv.FormatInt(decInt64(p), 10)) + default: + return errWrongDataLen + } + return nil +} + +func decReflectNullableR(p []byte, v reflect.Value) reflect.Value { + if p == nil { + return reflect.Zero(v.Elem().Type()) + } + return reflect.New(v.Type().Elem().Elem()) +} + +func decReflectInt8R(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + v.Elem().Set(decReflectNullableR(p, v)) + case 2: + val := decInt64(p) + if val > math.MaxInt8 || val < math.MinInt8 { + return fmt.Errorf("failed to unmarshal smallint: to unmarshal into (%T), the data should be in the int8 range", v.Interface()) + } + newVal := reflect.New(v.Type().Elem().Elem()) + newVal.Elem().SetInt(val) + v.Elem().Set(newVal) + default: + return errWrongDataLen + } + return nil +} + +func decReflectIntsR(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + v.Elem().Set(decReflectNullableR(p, v)) + case 2: + val := reflect.New(v.Type().Elem().Elem()) + val.Elem().SetInt(decInt64(p)) + v.Elem().Set(val) + default: + return errWrongDataLen + } + return nil +} + +func decReflectUint8R(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + v.Elem().Set(decReflectNullableR(p, v)) + case 2: + if p[0] != 0 { + return fmt.Errorf("failed to unmarshal smallint: to unmarshal into (%T), the data should be in the uint8 range", v.Interface()) + } + newVal := reflect.New(v.Type().Elem().Elem()) + newVal.Elem().SetUint(uint64(p[1])) + v.Elem().Set(newVal) + default: + return errWrongDataLen + } + return nil +} + +func decReflectUintsR(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + v.Elem().Set(decReflectNullableR(p, v)) + case 2: + val := reflect.New(v.Type().Elem().Elem()) + val.Elem().SetUint(decUint64(p)) + v.Elem().Set(val) + default: + return errWrongDataLen + } + return nil +} + +func decReflectStringR(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + var val reflect.Value + if p == nil { + val = reflect.Zero(v.Type().Elem()) + } else { + val = reflect.New(v.Type().Elem().Elem()) + val.Elem().SetString("0") + } + v.Elem().Set(val) + case 2: + val := reflect.New(v.Type().Elem().Elem()) + val.Elem().SetString(strconv.FormatInt(decInt64(p), 10)) + v.Elem().Set(val) default: return errWrongDataLen } @@ -456,6 +660,39 @@ func decInt16(p []byte) int16 { return int16(p[0])<<8 | int16(p[1]) } +func decInt32(p []byte) int32 { + if p[0] > math.MaxInt8 { + return -65536 + int32(p[0])<<8 | int32(p[1]) + } + return int32(p[0])<<8 | int32(p[1]) +} + +func decInt64(p []byte) int64 { + if p[0] > math.MaxInt8 { + return -65536 + int64(p[0])<<8 | int64(p[1]) + } + return int64(p[0])<<8 | int64(p[1]) +} + +func decInt(p []byte) int { + if p[0] > math.MaxInt8 { + return -65536 + int(p[0])<<8 | int(p[1]) + } + return int(p[0])<<8 | int(p[1]) +} + +func decUint16(p []byte) uint16 { + return uint16(p[0])<<8 | uint16(p[1]) +} + +func decUint32(p []byte) uint32 { + return uint32(p[0])<<8 | uint32(p[1]) +} + func decUint64(p []byte) uint64 { return uint64(p[0])<<8 | uint64(p[1]) } + +func decUint(p []byte) uint { + return uint(p[0])<<8 | uint(p[1]) +}