softmax.cu — 206 lines (169 loc) · 6.59 KB
A standalone CUDA implementation of a numerically stable softmax over a 1-D array, using two-level (block, then global) max and sum reductions.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <cuda_runtime.h>
#include <iostream>
// First softmax pass: each block reduces its slice of `input` to a single
// maximum and writes it to blockMax[blockIdx.x].
// Launch layout: 1-D grid, 1-D blocks. Assumes blockDim.x is a power of two
// and <= 256 (the static shared-memory buffer size).
__global__ void findBlockMax(float *input, float *blockMax, int n) {
    const int gid = blockIdx.x * blockDim.x + threadIdx.x;
    const int lane = threadIdx.x;

    // Shared scratch for the tree reduction.
    __shared__ float scratch[256];

    // Out-of-range lanes contribute -INFINITY, the identity for max.
    float best = (gid < n) ? input[gid] : -INFINITY;
    scratch[lane] = best;
    __syncthreads();

    // Iterative tree reduction: the active half shrinks each pass; the
    // running max is kept in a register and mirrored into shared memory.
    for (int half = blockDim.x >> 1; half > 0; half >>= 1) {
        if (lane < half) {
            best = fmaxf(best, scratch[lane + half]);
            scratch[lane] = best;
        }
        __syncthreads();
    }

    // Lane 0 ends up holding the block-wide maximum.
    if (lane == 0) {
        blockMax[blockIdx.x] = scratch[0];
    }
}
// Second softmax pass: computes exp(input[i] - globalMax) for each element
// (stored to expValues) and reduces each block's exponentials to a partial
// sum in blockSum[blockIdx.x]. Subtracting the global max keeps expf from
// overflowing. Same layout assumptions as findBlockMax: power-of-two
// blockDim.x <= 256.
__global__ void computeExpAndBlockSum(float *input, float *expValues, float *blockSum, float globalMax, int n) {
    const int gid = blockIdx.x * blockDim.x + threadIdx.x;
    const int lane = threadIdx.x;

    // Shared scratch for the sum reduction.
    __shared__ float partial[256];

    // Out-of-range lanes contribute 0, the identity for addition.
    float e = 0.0f;
    if (gid < n) {
        e = expf(input[gid] - globalMax);
        expValues[gid] = e;  // persist the shifted exponential for pass 3
    }
    partial[lane] = e;
    __syncthreads();

    // Iterative tree reduction; running sum lives in a register and is
    // mirrored into shared memory for the next pass to read.
    for (int half = blockDim.x >> 1; half > 0; half >>= 1) {
        if (lane < half) {
            e += partial[lane + half];
            partial[lane] = e;
        }
        __syncthreads();
    }

    // Lane 0 holds this block's partial sum.
    if (lane == 0) {
        blockSum[blockIdx.x] = partial[0];
    }
}
// Folds `numBlocks` per-block maxima into globalMax[0].
// Launch with a SINGLE block; blockDim.x must be a power of two and <= 256.
// Fix/generalization: the original read only blockMax[tid] and therefore
// silently ignored entries beyond blockDim.x, producing a wrong maximum
// whenever numBlocks > blockDim.x. A strided preload now folds every entry,
// so any numBlocks is handled correctly.
__global__ void reduceGlobalMax(float *blockMax, float *globalMax, int numBlocks) {
    int tid = threadIdx.x;
    __shared__ float s_data[256];

    // Strided preload: thread `tid` folds entries tid, tid+blockDim.x, ...
    // -INFINITY is the identity for max, so idle lanes are harmless.
    float val = -INFINITY;
    for (int i = tid; i < numBlocks; i += blockDim.x) {
        val = fmaxf(val, blockMax[i]);
    }
    s_data[tid] = val;
    __syncthreads();

    // Tree reduction. Every shared slot was initialized above, so no
    // additional bounds guard against numBlocks is needed here.
    for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
        if (tid < stride) {
            val = fmaxf(val, s_data[tid + stride]);
            s_data[tid] = val;
        }
        __syncthreads();
    }

    // Thread 0 publishes the global maximum.
    if (tid == 0) {
        globalMax[0] = s_data[0];
    }
}
// Folds `numBlocks` per-block partial sums into globalSum[0].
// Launch with a SINGLE block; blockDim.x must be a power of two and <= 256.
// Fix/generalization: the original read only blockSum[tid] and therefore
// silently dropped entries beyond blockDim.x whenever numBlocks exceeded the
// block size. A strided preload now accumulates every entry.
__global__ void reduceGlobalSum(float *blockSum, float *globalSum, int numBlocks) {
    int tid = threadIdx.x;
    __shared__ float s_data[256];

    // Strided preload: thread `tid` accumulates entries tid, tid+blockDim.x, ...
    // 0.0f is the additive identity, so idle lanes are harmless.
    float val = 0.0f;
    for (int i = tid; i < numBlocks; i += blockDim.x) {
        val += blockSum[i];
    }
    s_data[tid] = val;
    __syncthreads();

    // Tree reduction. Every shared slot was initialized above, so no
    // additional bounds guard against numBlocks is needed here.
    for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
        if (tid < stride) {
            val += s_data[tid + stride];
            s_data[tid] = val;
        }
        __syncthreads();
    }

    // Thread 0 publishes the global sum.
    if (tid == 0) {
        globalSum[0] = s_data[0];
    }
}
// Final softmax pass: divides each shifted exponential by the global sum so
// the outputs form a probability distribution (they sum to ~1).
// One thread per element; extra threads past n simply return.
__global__ void normalizeSoftmax(float *expValues, float *output, float globalSum, int n) {
    const int gid = blockIdx.x * blockDim.x + threadIdx.x;
    if (gid >= n) {
        return;  // tail guard: grid may overshoot n
    }
    output[gid] = expValues[gid] / globalSum;
}
// Abort with a diagnostic on any CUDA runtime failure. Fix: the original
// checked no API call or launch, so errors (bad alloc, bad launch config,
// in-kernel faults) would surface only as garbage output.
#define CUDA_CHECK(call)                                                      \
    do {                                                                      \
        cudaError_t err_ = (call);                                            \
        if (err_ != cudaSuccess) {                                            \
            std::cerr << "CUDA error at " << __FILE__ << ":" << __LINE__      \
                      << ": " << cudaGetErrorString(err_) << std::endl;       \
            std::exit(EXIT_FAILURE);                                          \
        }                                                                     \
    } while (0)

// Driver: computes softmax of N values in five device passes
// (block max -> global max -> shifted exp + block sum -> global sum ->
// normalize), then verifies on the host that the outputs sum to ~1.
int main() {
    const int N = 1024;                      // Array size
    const size_t size = N * sizeof(float);

    // Unified (managed) memory is visible to host and device, but the host
    // must synchronize with the device before reading values kernels wrote.
    float *input, *output, *expValues;
    CUDA_CHECK(cudaMallocManaged(&input, size));
    CUDA_CHECK(cudaMallocManaged(&output, size));
    CUDA_CHECK(cudaMallocManaged(&expValues, size));

    // Initialize input with values 0.0, 0.1, ..., ~102.3.
    for (int i = 0; i < N; i++) {
        input[i] = static_cast<float>(i) / 10.0f;
    }

    // Launch configuration: ceil-div so the grid covers all N elements.
    const int threadsPerBlock = 256;
    const int numBlocks = (N + threadsPerBlock - 1) / threadsPerBlock;

    // Scratch buffers for the two-level reductions.
    // NOTE(review): the single-block reduce kernels as originally written
    // assume numBlocks <= threadsPerBlock — true here (numBlocks == 4).
    float *blockMax, *blockSum, *globalMax, *globalSum;
    CUDA_CHECK(cudaMallocManaged(&blockMax, numBlocks * sizeof(float)));
    CUDA_CHECK(cudaMallocManaged(&blockSum, numBlocks * sizeof(float)));
    CUDA_CHECK(cudaMallocManaged(&globalMax, sizeof(float)));
    CUDA_CHECK(cudaMallocManaged(&globalSum, sizeof(float)));

    std::cout << "Computing softmax for " << N << " elements..." << std::endl;
    std::cout << "Using " << numBlocks << " blocks with " << threadsPerBlock << " threads each" << std::endl;
    std::cout << "Total threads: " << numBlocks * threadsPerBlock << std::endl;

    // Step 1: per-block maxima. No sync needed afterwards: step 2 runs on the
    // same (default) stream, so the kernels are already ordered.
    findBlockMax<<<numBlocks, threadsPerBlock>>>(input, blockMax, N);
    CUDA_CHECK(cudaGetLastError());

    // Step 2: global max. Sync IS required here because the host reads
    // globalMax[0] to pass it by value to the next launch.
    reduceGlobalMax<<<1, threadsPerBlock>>>(blockMax, globalMax, numBlocks);
    CUDA_CHECK(cudaGetLastError());
    CUDA_CHECK(cudaDeviceSynchronize());

    // Step 3: exp(x - globalMax) plus per-block sums. Again no sync needed
    // before the next same-stream launch.
    computeExpAndBlockSum<<<numBlocks, threadsPerBlock>>>(input, expValues, blockSum, globalMax[0], N);
    CUDA_CHECK(cudaGetLastError());

    // Step 4: global sum. Sync required: host reads globalSum[0] below.
    reduceGlobalSum<<<1, threadsPerBlock>>>(blockSum, globalSum, numBlocks);
    CUDA_CHECK(cudaGetLastError());
    CUDA_CHECK(cudaDeviceSynchronize());

    // Step 5: normalize. Sync before the host reads `output`.
    normalizeSoftmax<<<numBlocks, threadsPerBlock>>>(expValues, output, globalSum[0], N);
    CUDA_CHECK(cudaGetLastError());
    CUDA_CHECK(cudaDeviceSynchronize());

    // Verify results.
    std::cout << "\nFirst 10 softmax values:" << std::endl;
    for (int i = 0; i < 10; i++) {
        std::cout << "output[" << i << "] = " << output[i] << std::endl;
    }

    // The softmax outputs should sum to ~1.0.
    float totalSum = 0.0f;
    for (int i = 0; i < N; i++) {
        totalSum += output[i];
    }
    std::cout << "\nGlobal max: " << globalMax[0] << std::endl;
    std::cout << "Global sum: " << globalSum[0] << std::endl;
    std::cout << "Sum of all softmax values: " << totalSum << " (should be ~1.0)" << std::endl;
    std::cout << "Verification: " << (fabs(totalSum - 1.0f) < 0.01f ? "PASSED" : "FAILED") << std::endl;

    // Cleanup.
    CUDA_CHECK(cudaFree(input));
    CUDA_CHECK(cudaFree(output));
    CUDA_CHECK(cudaFree(expValues));
    CUDA_CHECK(cudaFree(blockMax));
    CUDA_CHECK(cudaFree(blockSum));
    CUDA_CHECK(cudaFree(globalMax));
    CUDA_CHECK(cudaFree(globalSum));
    return 0;
}