Skip to content

Commit a136bd3

Browse files
committed
add: Expand module
1 parent d55d34f commit a136bd3

5 files changed

Lines changed: 290 additions & 2 deletions

File tree

src/GraphBLAS-sharp.Backend/Common/ClArray.fs

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,30 @@ module ClArray =
3333

3434
outputArray
3535

36+
/// <summary>
/// Builds a scatter-initialisation function: for every position <c>i</c> of <c>indices</c>,
/// <c>result.[indices.[i]]</c> is set to <c>initializer i</c>.
/// </summary>
/// <param name="clContext">Context used to compile the kernel.</param>
/// <param name="workGroupSize">Work group size used to build the launch range.</param>
/// <param name="initializer">Quoted function applied to the position in <c>indices</c> (not to the target index).</param>
/// <returns>
/// A function taking the processor, the target indices and the array to write into.
/// It only posts messages to the processor and returns <c>unit</c>.
/// </returns>
/// <remarks>
/// Target indices outside <c>result</c> are skipped instead of being written out of bounds.
/// If several positions hold the same target index, which write wins is not specified.
/// </remarks>
let assignManyInit (clContext: ClContext) workGroupSize (initializer: Expr<int -> 'a>) =

    let init =
        <@ fun (range: Range1D) indicesLength resultLength (indices: ClArray<int>) (outputBuffer: ClArray<'a>) ->

            let gid = range.GlobalID0

            if gid < indicesLength then
                let targetIndex = indices.[gid]

                // Guard the scatter: an index equal to (or beyond) the buffer length must not be written.
                if targetIndex >= 0 && targetIndex < resultLength then
                    outputBuffer.[targetIndex] <- (%initializer) gid @>

    let program = clContext.Compile(init)

    fun (processor: MailboxProcessor<_>) (indices: ClArray<int>) (result: ClArray<'a>) ->

        let kernel = program.GetKernel()

        // One work item per entry of indices.
        let ndRange =
            Range1D.CreateValid(indices.Length, workGroupSize)

        processor.Post(
            Msg.MsgSetArguments
                (fun () -> kernel.KernelFunc ndRange indices.Length result.Length indices result)
        )

        processor.Post(Msg.CreateRunMsg<_, _> kernel)
59+
3660
let create (clContext: ClContext) workGroupSize =
3761

3862
let create =
@@ -62,7 +86,7 @@ module ClArray =
6286

6387
outputArray
6488

65-
let zeroCreate (clContext: ClContext) workGroupSize =
89+
let zeroCreate<'a> (clContext: ClContext) workGroupSize =
6690

6791
let create = create clContext workGroupSize
6892

src/GraphBLAS-sharp.Backend/Common/Gather.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
namespace GraphBLAS.FSharp.Backend.Common.Gather
1+
namespace GraphBLAS.FSharp.Backend.Common
22

33
open Brahma.FSharp
44

src/GraphBLAS-sharp.Backend/GraphBLAS-sharp.Backend.fsproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
<Compile Include="Matrix/CSRMatrix/Map2.fs" />
4141
<Compile Include="Matrix/CSRMatrix/SpGEMM.fs" />
4242
<Compile Include="Matrix/CSRMatrix/CSRMatrix.fs" />
43+
<Compile Include="Matrix/CSRMatrix/SpGEMM/Expand.fs" />
4344
<Compile Include="Matrix/Matrix.fs" />
4445
<Compile Include="Vector/SparseVector/Map2.fs" />
4546
<Compile Include="Vector/SparseVector/SparseVector.fs" />
Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
namespace GraphBLAS.FSharp.Backend.Matrix.CSRMatrix.SpGEMM
2+
3+
open Brahma.FSharp
4+
open GraphBLAS.FSharp.Backend.Common
5+
open GraphBLAS.FSharp.Backend.Predefined
6+
open GraphBLAS.FSharp.Backend.Objects.ClContext
7+
open GraphBLAS.FSharp.Backend.Objects
8+
open GraphBLAS.FSharp.Backend.Objects.ClCell
9+
10+
/// Abbreviation for <c>ClArray&lt;int&gt;</c>; used in this file for arrays of column indices, row pointers and positions.
type Indices = ClArray<int>
11+
12+
/// Abbreviation for <c>ClArray&lt;'a&gt;</c>; used in this file for arrays of matrix values.
type Values<'a> = ClArray<'a>
13+
14+
/// Expansion phase of CSR x CSR multiplication: for every non-zero of the left matrix,
/// the row of the right matrix selected by that non-zero's column index is laid out in one
/// flat ("global") array and multiplied element-wise by the left value.
module Expand =
    /// <summary>
    /// For the left-matrix non-zero number <c>gid</c>, gets the number of stored elements in the
    /// right-matrix row selected by that non-zero's column index.
    /// </summary>
    let requiredRawsLengths =
        <@ fun gid (leftMatrixColumnsIndices: Indices) (rightMatrixRawPointers: Indices) ->
            let columnIndex = leftMatrixColumnsIndices.[gid]
            let startRawIndex = rightMatrixRawPointers.[columnIndex]
            let exclusiveRawEndIndex = rightMatrixRawPointers.[columnIndex + 1]

            exclusiveRawEndIndex - startRawIndex @>

    /// <summary>
    /// For the left-matrix non-zero number <c>gid</c>, gets the pointer to the beginning of the
    /// right-matrix row selected by that non-zero's column index.
    /// </summary>
    let requiredRawPointers =
        <@ fun gid (leftMatrixColumnsIndices: Indices) (rightMatrixRawPointers: Indices) ->
            let columnIndex = leftMatrixColumnsIndices.[gid]
            let startRawIndex = rightMatrixRawPointers.[columnIndex]

            startRawIndex @>

    /// <summary>
    /// Builds a function that evaluates <c>writeOperation</c> once per left-matrix non-zero and
    /// stores the results in a newly allocated array of the same length as the left column indices.
    /// </summary>
    /// <param name="clContext">Context used to compile the kernel and allocate the result.</param>
    /// <param name="workGroupSize">Work group size used to build the launch range.</param>
    /// <param name="writeOperation">
    /// Quoted function of (non-zero index, left column indices, right row pointers) producing the value to store,
    /// e.g. <c>requiredRawsLengths</c> or <c>requiredRawPointers</c>.
    /// </param>
    /// <returns>A function returning the new array; the caller is responsible for freeing it.</returns>
    let processLeftMatrixColumnsAndRightMatrixRawPointers (clContext: ClContext) workGroupSize writeOperation =

        let kernel =
            <@ fun (ndRange: Range1D) columnsLength (leftMatrixColumnsIndices: Indices) (rightMatrixRawPointers: Indices) (result: Indices) ->

                let gid = ndRange.GlobalID0

                if gid < columnsLength then
                    result.[gid] <- (%writeOperation) gid leftMatrixColumnsIndices rightMatrixRawPointers @>

        let kernel = clContext.Compile kernel

        fun (processor: MailboxProcessor<_>) (leftMatrixColumnsIndices: Indices) (rightMatrixRawPointers: Indices) ->
            let resultLength = leftMatrixColumnsIndices.Length

            // One output element per left-matrix non-zero.
            let result =
                clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, resultLength)

            let kernel = kernel.GetKernel()

            let ndRange =
                Range1D.CreateValid(resultLength, workGroupSize)

            processor.Post(
                Msg.MsgSetArguments
                    (fun () ->
                        kernel.KernelFunc
                            ndRange
                            resultLength
                            leftMatrixColumnsIndices
                            rightMatrixRawPointers
                            result)
            )

            processor.Post(Msg.CreateRunMsg<_, _>(kernel))

            result

    /// <summary>
    /// Builds a function that labels every position of the expanded array with the 1-based number
    /// of the segment it belongs to: units are written at the segment start positions and an
    /// inclusive prefix sum is taken over the result.
    /// </summary>
    /// <returns>
    /// A function of (processor, expanded length, segment start positions) returning a new array of labels;
    /// the caller is responsible for freeing it.
    /// </returns>
    let getGlobalPositions (clContext: ClContext) workGroupSize =

        let zeroCreate = ClArray.zeroCreate<int> clContext workGroupSize

        let assignUnits = ClArray.assignManyInit clContext workGroupSize <@ fun _ -> 1 @>

        let prefixSum = PrefixSum.standardIncludeInplace clContext workGroupSize

        fun (processor: MailboxProcessor<_>) resultLength (globalRightMatrixValuesPositions: Indices) ->

            // Start from an array of zeros
            let globalPositions = zeroCreate processor DeviceOnly resultLength

            // Insert units at the beginning of new rows (source positions).
            // NOTE(review): when two segments share a start position (i.e. one of them is empty) only a
            // single unit lands there, so the following segments get a label that is too small; a start
            // position equal to resultLength (trailing empty segments) addresses an element past the end
            // of globalPositions. Confirm whether empty right-matrix rows can reach this code; if so the
            // markers probably need to be accumulated (atomic add) rather than assigned.
            assignUnits processor globalRightMatrixValuesPositions globalPositions

            // Apply the inclusive prefix sum in place: positions of the same segment get equal values,
            // positions of different segments get different values. The cell returned by the scan is not needed.
            (prefixSum processor globalPositions).Free processor

            globalPositions

    /// <summary>
    /// Builds a function that computes, for every position of the expanded array, the pointer into the
    /// right-matrix values/columns arrays of the element that belongs at that position.
    /// </summary>
    /// <returns>
    /// A function of (processor, expanded length, segment start positions, per-segment row start pointers,
    /// segment labels) returning a new array of pointers. None of the input arrays is freed;
    /// the caller owns both the inputs and the result.
    /// </returns>
    let getRightMatrixPointers (clContext: ClContext) workGroupSize =

        let kernel =
            <@ fun (ndRange: Range1D) length (globalRightMatrixValuesPositions: Indices) (requiredRightMatrixValuesPointers: Indices) (globalPositions: Indices) (result: Indices) ->

                let gid = ndRange.GlobalID0

                if gid < length then
                    // labels are 1-based; turn the label into the index of the segment
                    let positionIndex = globalPositions.[gid] - 1

                    // the position in the expanded array where this segment begins
                    let sourcePosition = globalRightMatrixValuesPositions.[positionIndex]

                    // offset of the current position from the beginning of its segment
                    let offsetFromSourcePosition = gid - sourcePosition

                    // pointer to the first element in the row of the right matrix from which
                    // the offset will be counted to get pointers to subsequent elements in this row
                    let sourcePointer = requiredRightMatrixValuesPointers.[positionIndex]

                    // adding the offset to the source pointer,
                    // we get a pointer to a specific element in the row
                    result.[gid] <- sourcePointer + offsetFromSourcePosition @>

        let kernel = clContext.Compile kernel

        fun (processor: MailboxProcessor<_>) (resultLength: int) (globalRightMatrixValuesPositions: Indices) (requiredRightMatrixValuesPointers: Indices) (globalPositions: Indices) ->

            let globalRightMatrixValuesPointers =
                clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, resultLength)

            let kernel = kernel.GetKernel()

            let ndRange =
                Range1D.CreateValid(resultLength, workGroupSize)

            processor.Post(
                Msg.MsgSetArguments
                    (fun () ->
                        kernel.KernelFunc
                            ndRange
                            resultLength
                            globalRightMatrixValuesPositions
                            requiredRightMatrixValuesPointers
                            globalPositions
                            globalRightMatrixValuesPointers)
            )

            processor.Post <| Msg.CreateRunMsg<_, _> kernel

            // globalPositions is deliberately NOT freed here: the segment labels are still needed
            // afterwards to expand the left-matrix values.

            globalRightMatrixValuesPointers

    /// <summary>
    /// Builds a function that expands the left-matrix values to the length of the expanded array:
    /// every position receives the value of the left-matrix non-zero whose segment it belongs to.
    /// </summary>
    /// <returns>
    /// A function of (processor, expanded length, segment labels, left-matrix values) returning a new
    /// array of values. The segment labels array is freed by this function, so it must be its last consumer.
    /// </returns>
    let getLeftMatrixValuesCorrespondinglyToPositionsPattern<'a> (clContext: ClContext) workGroupSize =

        let kernel =
            <@ fun (ndRange: Range1D) globalLength (globalPositions: Indices) (leftMatrixValues: ClArray<'a>) (result: ClArray<'a>) ->

                let gid = ndRange.GlobalID0

                if gid < globalLength then
                    // labels are 1-based; the segment index is the index of the left-matrix non-zero
                    let valuePosition = globalPositions.[gid] - 1

                    result.[gid] <- leftMatrixValues.[valuePosition] @>

        let kernel = clContext.Compile kernel

        fun (processor: MailboxProcessor<_>) (globalLength: int) (globalPositions: Indices) (leftMatrixValues: Values<'a>) ->

            // globalLength == globalPositions.Length
            let resultLeftMatrixValues =
                clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, globalLength)

            let kernel = kernel.GetKernel()

            let ndRange =
                Range1D.CreateValid(globalLength, workGroupSize)

            processor.Post(
                Msg.MsgSetArguments
                    (fun () ->
                        kernel.KernelFunc
                            ndRange
                            globalLength
                            globalPositions
                            leftMatrixValues
                            resultLeftMatrixValues)
            )

            processor.Post <| Msg.CreateRunMsg<_, _> kernel
            processor.Post <| Msg.CreateFreeMsg globalPositions

            resultLeftMatrixValues

    /// <summary>
    /// Builds the expansion step: every non-zero of the left matrix is combined, via <c>multiplication</c>,
    /// with every stored element of the right-matrix row selected by its column index.
    /// </summary>
    /// <param name="clContext">Context used to compile kernels and allocate arrays.</param>
    /// <param name="workGroupSize">Work group size used by all kernels.</param>
    /// <param name="multiplication">Operation handed to <c>ClArray.map2</c> with (left value, right value).</param>
    /// <returns>
    /// A function of (processor, left CSR matrix, right CSR matrix) returning the array of products and
    /// the array of right-matrix column indices aligned with it. Both results are owned by the caller;
    /// every intermediate array allocated here is freed before returning.
    /// </returns>
    let run (clContext: ClContext) workGroupSize multiplication =

        let getRequiredRawsLengths =
            processLeftMatrixColumnsAndRightMatrixRawPointers clContext workGroupSize requiredRawsLengths

        let prefixSumExclude =
            PrefixSum.standardExcludeInplace clContext workGroupSize

        let getRequiredRightMatrixValuesPointers =
            processLeftMatrixColumnsAndRightMatrixRawPointers clContext workGroupSize requiredRawPointers

        let getRightMatrixValuesPointers =
            getRightMatrixPointers clContext workGroupSize

        let getGlobalPositions = getGlobalPositions clContext workGroupSize

        let gatherRightMatrixData = Gather.run clContext workGroupSize

        let gatherIndices = Gather.run clContext workGroupSize

        let getLeftMatrixValues =
            getLeftMatrixValuesCorrespondinglyToPositionsPattern clContext workGroupSize

        let map2 = ClArray.map2 clContext workGroupSize multiplication

        fun (processor: MailboxProcessor<_>) (leftMatrix: ClMatrix.CSR<'a>) (rightMatrix: ClMatrix.CSR<'b>) ->

            // length of the required right-matrix row for every left-matrix non-zero
            let requiredRawsLengths =
                getRequiredRawsLengths processor leftMatrix.Columns rightMatrix.RowPointers

            // global expanded array length (the value read back from the cell returned by the scan)
            let globalLength =
                (prefixSumExclude processor requiredRawsLengths).ToHostAndFree processor

            // the exclusive prefix sum was computed in place, so the same array now holds
            // the start position of every row (segment) in the global array
            let globalRightMatrixValuesRawsStartPositions = requiredRawsLengths

            // pointers to required rows in right matrix values
            let requiredRightMatrixValuesPointers =
                getRequiredRightMatrixValuesPointers processor leftMatrix.Columns rightMatrix.RowPointers

            // segment labels to distinguish different rows in the global array
            let globalPositions =
                getGlobalPositions processor globalLength globalRightMatrixValuesRawsStartPositions

            // extended pointers to all required right matrix numbers;
            // argument order: start positions, row start pointers, segment labels
            let globalRightMatrixValuesPointers =
                getRightMatrixValuesPointers
                    processor
                    globalLength
                    globalRightMatrixValuesRawsStartPositions
                    requiredRightMatrixValuesPointers
                    globalPositions

            // per-segment arrays are no longer needed once the extended pointers are scheduled
            processor.Post(Msg.CreateFreeMsg<_>(globalRightMatrixValuesRawsStartPositions))
            processor.Post(Msg.CreateFreeMsg<_>(requiredRightMatrixValuesPointers))

            // gather all required right matrix values
            let extendedRightMatrixValues =
                clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, globalLength)

            gatherRightMatrixData processor globalRightMatrixValuesPointers rightMatrix.Values extendedRightMatrixValues

            // gather all required right matrix column indices
            let extendedRightMatrixColumns =
                clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, globalLength)

            gatherIndices processor globalRightMatrixValuesPointers rightMatrix.Columns extendedRightMatrixColumns

            processor.Post(Msg.CreateFreeMsg<_>(globalRightMatrixValuesPointers))

            // left matrix values correspondingly to right matrix values
            // (this call is the last consumer of globalPositions and frees it)
            let extendedLeftMatrixValues =
                getLeftMatrixValues processor globalLength globalPositions leftMatrix.Values

            let multiplicationResult =
                map2 processor DeviceOnly extendedLeftMatrixValues extendedRightMatrixValues

            // the expanded operands are only inputs of the multiplication
            processor.Post(Msg.CreateFreeMsg<_>(extendedLeftMatrixValues))
            processor.Post(Msg.CreateFreeMsg<_>(extendedRightMatrixValues))

            multiplicationResult, extendedRightMatrixColumns

src/GraphBLAS-sharp.Backend/Objects/ClCell.fs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,6 @@ module ClCell =
1111
processor.Post(Msg.CreateFreeMsg<_>(this))
1212

1313
res.[0]
14+
15+
/// Posts a free message for this cell to the given processor.
member this.Free(processor: MailboxProcessor<_>) =
    processor.Post <| Msg.CreateFreeMsg<_>(this)

0 commit comments

Comments
 (0)