-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathbfloat16_arithmetic.go
More file actions
211 lines (179 loc) · 5.76 KB
/
bfloat16_arithmetic.go
File metadata and controls
211 lines (179 loc) · 5.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
package float16
import "math"
// BFloat16AddWithMode performs addition with specified arithmetic and rounding modes.
func BFloat16AddWithMode(a, b BFloat16, mode ArithmeticMode, rounding RoundingMode) (BFloat16, error) {
	// NaN propagation: if either operand is NaN, the result is NaN (or an error
	// in exact mode).
	if a.IsNaN() || b.IsNaN() {
		if mode == ModeExactArithmetic {
			return 0, &BFloat16Error{Op: "bfloat16_add", Msg: "NaN operand in exact mode", Code: ErrNaN}
		}
		return BFloat16QuietNaN, nil
	}

	// Zeros. Returning the other operand unchanged is only valid when that
	// operand is non-zero: IEEE 754 (round-to-nearest) requires +0 + -0 == +0,
	// so the sum of two zeros is -0 only when both are -0. This also makes
	// BFloat16SubWithMode produce +0 for 0 - 0.
	// NOTE(review): IEEE 754 specifies -0 for mixed-sign zeros under
	// round-toward-negative; rounding is not consulted here — confirm whether
	// that mode needs special-casing.
	aZero, bZero := a.IsZero(), b.IsZero()
	switch {
	case aZero && bZero:
		if a.Signbit() && b.Signbit() {
			return BFloat16NegativeZero, nil
		}
		return BFloat16PositiveZero, nil
	case aZero:
		return b, nil
	case bZero:
		return a, nil
	}

	// Infinities: opposite-signed infinities have no defined sum; any other
	// combination involving an infinity yields that infinity.
	if a.IsInf(0) || b.IsInf(0) {
		if (a.IsInf(1) && b.IsInf(-1)) || (a.IsInf(-1) && b.IsInf(1)) {
			if mode == ModeExactArithmetic {
				return 0, &BFloat16Error{Op: "bfloat16_add", Msg: "infinity - infinity is undefined", Code: ErrInvalidOperation}
			}
			return BFloat16QuietNaN, nil
		}
		if a.IsInf(0) {
			return a, nil
		}
		return b, nil
	}

	if mode == ModeFastArithmetic {
		return BFloat16FromFloat32(a.ToFloat32() + b.ToFloat32()), nil
	}

	// IEEE mode: compute in float32 with the requested rounding.
	result := a.ToFloat32() + b.ToFloat32()
	bf := BFloat16FromFloat32WithRounding(result, rounding)
	// Gradual underflow: if the float32 result is non-zero but rounds to BFloat16 zero,
	// return the smallest subnormal with the correct sign instead.
	if result != 0 && bf.IsZero() {
		if result > 0 {
			return BFloat16SmallestPosSubnormal, nil
		}
		return BFloat16SmallestNegSubnormal, nil
	}
	return bf, nil
}
// BFloat16SubWithMode performs subtraction with specified arithmetic and rounding modes.
func BFloat16SubWithMode(a, b BFloat16, mode ArithmeticMode, rounding RoundingMode) (BFloat16, error) {
	// Subtraction is expressed as a + (-b). Special-value handling, rounding and
	// any *BFloat16Error (reported with Op "bfloat16_add") therefore all come
	// from BFloat16AddWithMode.
	negatedB := BFloat16Neg(b)
	return BFloat16AddWithMode(a, negatedB, mode, rounding)
}
// BFloat16MulWithMode performs multiplication with specified arithmetic and rounding modes.
func BFloat16MulWithMode(a, b BFloat16, mode ArithmeticMode, rounding RoundingMode) (BFloat16, error) {
	// Special operands are resolved before any arithmetic, in this order:
	//   NaN        -> ErrNaN error (exact mode) or BFloat16QuietNaN
	//   0 * Inf    -> ErrInvalidOperation error (exact mode) or BFloat16QuietNaN
	//   zero       -> zero whose sign is the XOR of the operand signs
	//   infinity   -> infinity whose sign is the XOR of the operand signs
	exact := mode == ModeExactArithmetic

	if a.IsNaN() || b.IsNaN() {
		if exact {
			return 0, &BFloat16Error{Op: "bfloat16_mul", Msg: "NaN operand in exact mode", Code: ErrNaN}
		}
		return BFloat16QuietNaN, nil
	}

	var signedZero, signedInf BFloat16 = BFloat16PositiveZero, BFloat16PositiveInfinity
	if a.Signbit() != b.Signbit() {
		signedZero, signedInf = BFloat16NegativeZero, BFloat16NegativeInfinity
	}

	hasZero := a.IsZero() || b.IsZero()
	hasInf := a.IsInf(0) || b.IsInf(0)

	switch {
	case hasZero && hasInf:
		if exact {
			return 0, &BFloat16Error{Op: "bfloat16_mul", Msg: "zero times infinity is undefined", Code: ErrInvalidOperation}
		}
		return BFloat16QuietNaN, nil
	case hasZero:
		return signedZero, nil
	case hasInf:
		return signedInf, nil
	}

	// Finite, non-zero operands: multiply in float32 and narrow.
	product := a.ToFloat32() * b.ToFloat32()
	if mode == ModeFastArithmetic {
		return BFloat16FromFloat32(product), nil
	}

	// Non-fast modes honour the rounding mode and never let a non-zero product
	// collapse to zero: it becomes the smallest subnormal of the same sign.
	rounded := BFloat16FromFloat32WithRounding(product, rounding)
	switch {
	case product == 0 || !rounded.IsZero():
		return rounded, nil
	case product > 0:
		return BFloat16SmallestPosSubnormal, nil
	default:
		return BFloat16SmallestNegSubnormal, nil
	}
}
// BFloat16DivWithMode performs division with specified arithmetic and rounding modes.
func BFloat16DivWithMode(a, b BFloat16, mode ArithmeticMode, rounding RoundingMode) (BFloat16, error) {
	// Special operands are resolved before any arithmetic, in this order:
	//   NaN            -> ErrNaN error (exact mode) or BFloat16QuietNaN
	//   0 / 0          -> ErrInvalidOperation error (exact mode) or BFloat16QuietNaN
	//   x / 0          -> ErrDivisionByZero error (exact mode) or signed infinity
	//   0 / x          -> signed zero
	//   Inf / Inf      -> ErrInvalidOperation error (exact mode) or BFloat16QuietNaN
	//   Inf / finite   -> signed infinity
	//   finite / Inf   -> signed zero
	// "Signed" means the sign is the XOR of the operand signs.
	exact := mode == ModeExactArithmetic

	if a.IsNaN() || b.IsNaN() {
		if exact {
			return 0, &BFloat16Error{Op: "bfloat16_div", Msg: "NaN operand in exact mode", Code: ErrNaN}
		}
		return BFloat16QuietNaN, nil
	}

	var signedZero, signedInf BFloat16 = BFloat16PositiveZero, BFloat16PositiveInfinity
	if a.Signbit() != b.Signbit() {
		signedZero, signedInf = BFloat16NegativeZero, BFloat16NegativeInfinity
	}

	aZero, bZero := a.IsZero(), b.IsZero()
	aInf, bInf := a.IsInf(0), b.IsInf(0)

	switch {
	case aZero && bZero:
		if exact {
			return 0, &BFloat16Error{Op: "bfloat16_div", Msg: "zero divided by zero is undefined", Code: ErrInvalidOperation}
		}
		return BFloat16QuietNaN, nil
	case bZero:
		if exact {
			return 0, &BFloat16Error{Op: "bfloat16_div", Msg: "division by zero", Code: ErrDivisionByZero}
		}
		return signedInf, nil
	case aZero:
		return signedZero, nil
	case aInf && bInf:
		if exact {
			return 0, &BFloat16Error{Op: "bfloat16_div", Msg: "infinity divided by infinity is undefined", Code: ErrInvalidOperation}
		}
		return BFloat16QuietNaN, nil
	case aInf:
		return signedInf, nil
	case bInf:
		return signedZero, nil
	}

	// Finite, non-zero operands: divide in float32 and narrow.
	quotient := a.ToFloat32() / b.ToFloat32()
	if mode == ModeFastArithmetic {
		return BFloat16FromFloat32(quotient), nil
	}

	// Non-fast modes honour the rounding mode and never let a non-zero quotient
	// collapse to zero: it becomes the smallest subnormal of the same sign.
	rounded := BFloat16FromFloat32WithRounding(quotient, rounding)
	switch {
	case quotient == 0 || !rounded.IsZero():
		return rounded, nil
	case quotient > 0:
		return BFloat16SmallestPosSubnormal, nil
	default:
		return BFloat16SmallestNegSubnormal, nil
	}
}
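// BFloat16FMA computes the fused multiply-add a*b + c for BFloat16 values.
//
// If any operand is NaN, the result is BFloat16QuietNaN. Otherwise the
// operands are widened to float64 and combined with math.FMA, which rounds
// a*b + c once. The float64 result is then narrowed to float32 and converted
// with BFloat16FromFloat32. The returned error is always nil in this
// implementation.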
func BFloat16FMA(a, b, c BFloat16) (BFloat16, error) {
	// Implementation note: the doc comment preceding this function (kept
	// outside this edit) calls it a stub that returns an error. That is stale —
	// the body below computes a*b + c and always returns a nil error.
	//
	// Any NaN operand yields BFloat16QuietNaN.
	if a.IsNaN() || b.IsNaN() || c.IsNaN() {
		return BFloat16QuietNaN, nil
	}

	// Widen to float64 and let math.FMA compute a*b + c with a single
	// rounding to float64.
	fa := float64(a.ToFloat32())
	fb := float64(b.ToFloat32())
	fc := float64(c.ToFloat32())
	result := math.FMA(fa, fb, fc)

	// NOTE(review): the float64 result is narrowed to float32 and then to
	// BFloat16, so up to three roundings occur in total. The result may
	// therefore differ from a correctly rounded BFloat16 FMA in rare cases —
	// confirm whether callers need exact single rounding.
	return BFloat16FromFloat32(float32(result)), nil
}