Preserve input type.

This commit is contained in:
KN4CK3R 2024-11-14 13:03:30 +01:00 committed by Bjørn Erik Pedersen
parent 588c9019cf
commit 23d21b0d16
2 changed files with 30 additions and 27 deletions

View file

@ -26,29 +26,32 @@ func DoArithmetic(a, b any, op rune) (any, error) {
var ai, bi int64 var ai, bi int64
var af, bf float64 var af, bf float64
var au, bu uint64 var au, bu uint64
var isInt, isFloat, isUint bool
switch av.Kind() { switch av.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
ai = av.Int() ai = av.Int()
switch bv.Kind() { switch bv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
isInt = true
bi = bv.Int() bi = bv.Int()
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
isFloat = true
af = float64(ai) // may overflow af = float64(ai) // may overflow
ai = 0
bf = bv.Float() bf = bv.Float()
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
bu = bv.Uint() bu = bv.Uint()
if ai >= 0 { if ai >= 0 {
isUint = true
au = uint64(ai) au = uint64(ai)
ai = 0
} else { } else {
isInt = true
bi = int64(bu) // may overflow bi = int64(bu) // may overflow
bu = 0
} }
default: default:
return nil, errors.New("can't apply the operator to the values") return nil, errors.New("can't apply the operator to the values")
} }
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
isFloat = true
af = av.Float() af = av.Float()
switch bv.Kind() { switch bv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
@ -66,17 +69,18 @@ func DoArithmetic(a, b any, op rune) (any, error) {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
bi = bv.Int() bi = bv.Int()
if bi >= 0 { if bi >= 0 {
isUint = true
bu = uint64(bi) bu = uint64(bi)
bi = 0
} else { } else {
isInt = true
ai = int64(au) // may overflow ai = int64(au) // may overflow
au = 0
} }
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
isFloat = true
af = float64(au) // may overflow af = float64(au) // may overflow
au = 0
bf = bv.Float() bf = bv.Float()
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
isUint = true
bu = bv.Uint() bu = bv.Uint()
default: default:
return nil, errors.New("can't apply the operator to the values") return nil, errors.New("can't apply the operator to the values")
@ -94,38 +98,32 @@ func DoArithmetic(a, b any, op rune) (any, error) {
switch op { switch op {
case '+': case '+':
if ai != 0 || bi != 0 { if isInt {
return ai + bi, nil return ai + bi, nil
} else if af != 0 || bf != 0 { } else if isFloat {
return af + bf, nil return af + bf, nil
} else if au != 0 || bu != 0 { }
return au + bu, nil return au + bu, nil
}
return 0, nil
case '-': case '-':
if ai != 0 || bi != 0 { if isInt {
return ai - bi, nil return ai - bi, nil
} else if af != 0 || bf != 0 { } else if isFloat {
return af - bf, nil return af - bf, nil
} else if au != 0 || bu != 0 { }
return au - bu, nil return au - bu, nil
}
return 0, nil
case '*': case '*':
if ai != 0 || bi != 0 { if isInt {
return ai * bi, nil return ai * bi, nil
} else if af != 0 || bf != 0 { } else if isFloat {
return af * bf, nil return af * bf, nil
} else if au != 0 || bu != 0 {
return au * bu, nil
} }
return 0, nil return au * bu, nil
case '/': case '/':
if bi != 0 { if isInt && bi != 0 {
return ai / bi, nil return ai / bi, nil
} else if bf != 0 { } else if isFloat && bf != 0 {
return af / bf, nil return af / bf, nil
} else if bu != 0 { } else if isUint && bu != 0 {
return au / bu, nil return au / bu, nil
} }
return nil, errors.New("can't divide the value by 0") return nil, errors.New("can't divide the value by 0")

View file

@ -30,10 +30,12 @@ func TestDoArithmetic(t *testing.T) {
expect any expect any
}{ }{
{3, 2, '+', int64(5)}, {3, 2, '+', int64(5)},
{0, 0, '+', int64(0)},
{3, 2, '-', int64(1)}, {3, 2, '-', int64(1)},
{3, 2, '*', int64(6)}, {3, 2, '*', int64(6)},
{3, 2, '/', int64(1)}, {3, 2, '/', int64(1)},
{3.0, 2, '+', float64(5)}, {3.0, 2, '+', float64(5)},
{0.0, 0, '+', float64(0.0)},
{3.0, 2, '-', float64(1)}, {3.0, 2, '-', float64(1)},
{3.0, 2, '*', float64(6)}, {3.0, 2, '*', float64(6)},
{3.0, 2, '/', float64(1.5)}, {3.0, 2, '/', float64(1.5)},
@ -42,18 +44,22 @@ func TestDoArithmetic(t *testing.T) {
{3, 2.0, '*', float64(6)}, {3, 2.0, '*', float64(6)},
{3, 2.0, '/', float64(1.5)}, {3, 2.0, '/', float64(1.5)},
{3.0, 2.0, '+', float64(5)}, {3.0, 2.0, '+', float64(5)},
{0.0, 0.0, '+', float64(0.0)},
{3.0, 2.0, '-', float64(1)}, {3.0, 2.0, '-', float64(1)},
{3.0, 2.0, '*', float64(6)}, {3.0, 2.0, '*', float64(6)},
{3.0, 2.0, '/', float64(1.5)}, {3.0, 2.0, '/', float64(1.5)},
{uint(3), uint(2), '+', uint64(5)}, {uint(3), uint(2), '+', uint64(5)},
{uint(0), uint(0), '+', uint64(0)},
{uint(3), uint(2), '-', uint64(1)}, {uint(3), uint(2), '-', uint64(1)},
{uint(3), uint(2), '*', uint64(6)}, {uint(3), uint(2), '*', uint64(6)},
{uint(3), uint(2), '/', uint64(1)}, {uint(3), uint(2), '/', uint64(1)},
{uint(3), 2, '+', uint64(5)}, {uint(3), 2, '+', uint64(5)},
{uint(0), 0, '+', uint64(0)},
{uint(3), 2, '-', uint64(1)}, {uint(3), 2, '-', uint64(1)},
{uint(3), 2, '*', uint64(6)}, {uint(3), 2, '*', uint64(6)},
{uint(3), 2, '/', uint64(1)}, {uint(3), 2, '/', uint64(1)},
{3, uint(2), '+', uint64(5)}, {3, uint(2), '+', uint64(5)},
{0, uint(0), '+', uint64(0)},
{3, uint(2), '-', uint64(1)}, {3, uint(2), '-', uint64(1)},
{3, uint(2), '*', uint64(6)}, {3, uint(2), '*', uint64(6)},
{3, uint(2), '/', uint64(1)}, {3, uint(2), '/', uint64(1)},
@ -66,16 +72,15 @@ func TestDoArithmetic(t *testing.T) {
{-3, uint(2), '*', int64(-6)}, {-3, uint(2), '*', int64(-6)},
{-3, uint(2), '/', int64(-1)}, {-3, uint(2), '/', int64(-1)},
{uint(3), 2.0, '+', float64(5)}, {uint(3), 2.0, '+', float64(5)},
{uint(0), 0.0, '+', float64(0)},
{uint(3), 2.0, '-', float64(1)}, {uint(3), 2.0, '-', float64(1)},
{uint(3), 2.0, '*', float64(6)}, {uint(3), 2.0, '*', float64(6)},
{uint(3), 2.0, '/', float64(1.5)}, {uint(3), 2.0, '/', float64(1.5)},
{3.0, uint(2), '+', float64(5)}, {3.0, uint(2), '+', float64(5)},
{0.0, uint(0), '+', float64(0)},
{3.0, uint(2), '-', float64(1)}, {3.0, uint(2), '-', float64(1)},
{3.0, uint(2), '*', float64(6)}, {3.0, uint(2), '*', float64(6)},
{3.0, uint(2), '/', float64(1.5)}, {3.0, uint(2), '/', float64(1.5)},
{0, 0, '+', 0},
{0, 0, '-', 0},
{0, 0, '*', 0},
{"foo", "bar", '+', "foobar"}, {"foo", "bar", '+', "foobar"},
{3, 0, '/', false}, {3, 0, '/', false},
{3.0, 0, '/', false}, {3.0, 0, '/', false},