Skip to content

Commit 97281a6

Browse files
committed
add: spgemm coo->csr->coo
1 parent 00bdd5a commit 97281a6

4 files changed

Lines changed: 390 additions & 13 deletions

File tree

src/GraphBLAS-sharp.Backend/Algorithms/MSBFS.fs

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ module internal MSBFS =
4545

4646
let findIntersection = Intersect.findKeysIntersection clContext workGroupSize
4747

48-
fun (queue: MailboxProcessor<_>) allocationMode (levels: ClMatrix.COO<_>) (front: ClMatrix.COO<_>) ->
48+
fun (queue: MailboxProcessor<_>) allocationMode (front: ClMatrix.COO<_>) (levels: ClMatrix.COO<_>) ->
4949

5050
// Find intersection of levels and front indices.
5151
let intersection = findIntersection queue DeviceOnly front levels
@@ -64,15 +64,15 @@ module internal MSBFS =
6464

6565
let run<'a when 'a: struct>
6666
(add: Expr<int -> int -> int option>)
67-
(mul: Expr<'a -> int -> int option>)
67+
(mul: Expr<int -> 'a -> int option>)
6868
(clContext: ClContext)
6969
workGroupSize
7070
=
7171

7272
let spGeMM =
73-
Operations.SpGeMM.expand add mul clContext workGroupSize
73+
Operations.SpGeMM.COO.expand add mul clContext workGroupSize
7474

75-
let toCSRInPlace = Matrix.toCSRInPlace clContext workGroupSize
75+
let copy = Matrix.copy clContext workGroupSize
7676

7777
let updateFrontAndLevels = updateFrontAndLevels clContext workGroupSize
7878

@@ -88,10 +88,7 @@ module internal MSBFS =
8888
startMatrix
8989
|> Matrix.ofList clContext DeviceOnly sourceVertexCount vertexCount
9090

91-
let mutable front =
92-
startMatrix
93-
|> Matrix.ofList clContext DeviceOnly sourceVertexCount vertexCount
94-
|> toCSRInPlace queue DeviceOnly
91+
let mutable front = copy queue DeviceOnly levels
9592

9693
let mutable level = 0
9794
let mutable stop = false
@@ -100,16 +97,16 @@ module internal MSBFS =
10097
level <- level + 1
10198

10299
//Getting new frontier
103-
match spGeMM queue DeviceOnly matrix (ClMatrix.CSR front) with
100+
match spGeMM queue DeviceOnly (ClMatrix.COO front) matrix with
104101
| None ->
105102
front.Dispose queue
106103
stop <- true
107104
| Some newFrontier ->
108105
front.Dispose queue
109106
//Filtering visited vertices
110-
match updateFrontAndLevels queue DeviceOnly levels newFrontier with
107+
match updateFrontAndLevels queue DeviceOnly newFrontier levels with
111108
| l, Some f ->
112-
front <- toCSRInPlace queue DeviceOnly f
109+
front <- f
113110
levels.Dispose queue
114111
levels <- l
115112
newFrontier.Dispose queue

src/GraphBLAS-sharp.Backend/Matrix/COO/Matrix.fs

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,25 @@ open GraphBLAS.FSharp.Objects.ArraysExtensions
1111
open GraphBLAS.FSharp.Objects.ClContextExtensions
1212

1313
module Matrix =
14+
/// <summary>
/// Builds a new COO matrix that holds copies of the rows, columns and values
/// of the given matrix; the dimensions are carried over unchanged.
/// </summary>
/// <param name="clContext">OpenCL context.</param>
/// <param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
let copy (clContext: ClContext) workGroupSize =

    // Two separate copy instances: one is used for both index arrays, the other for the values.
    // NOTE(review): presumably each instance gets fixed to a single element type — confirm.
    let copyIndices = ClArray.copy clContext workGroupSize

    let copyValues = ClArray.copy clContext workGroupSize

    fun (processor: MailboxProcessor<_>) allocationMode (matrix: COO<'a>) ->
        // Copies are issued in the same order as before: rows, columns, values.
        let rows =
            copyIndices processor allocationMode matrix.Rows

        let columns =
            copyIndices processor allocationMode matrix.Columns

        let values =
            copyValues processor allocationMode matrix.Values

        { Context = clContext
          RowCount = matrix.RowCount
          ColumnCount = matrix.ColumnCount
          Rows = rows
          Columns = columns
          Values = values }
32+
1433
/// <summary>
1534
/// Builds a new COO matrix whose elements are the results of applying the given function
1635
/// to each of the elements of the matrix.
@@ -85,7 +104,7 @@ module Matrix =
85104
/// </summary>
86105
/// <param name="clContext">OpenCL context.</param>
87106
/// <param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
88-
let private compressRows (clContext: ClContext) workGroupSize =
107+
let compressRows (clContext: ClContext) workGroupSize =
89108

90109
let compressRows =
91110
<@ fun (ndRange: Range1D) (rows: ClArray<int>) (nnz: int) (rowPointers: ClArray<int>) ->
@@ -251,3 +270,62 @@ module Matrix =
251270
Values = clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, values)
252271
RowCount = rowCount
253272
ColumnCount = columnCount }
273+
274+
/// <summary>
/// Returns matrix composed of all elements from the given row range of the input matrix.
/// </summary>
/// <remarks>
/// The result keeps the row and column counts of the input matrix, and its row
/// indices are copied as they are (they are not shifted by <c>startRow</c>).
/// Fails when <c>count</c> is not positive, when <c>startRow</c> is negative, or when
/// <c>startRow + count</c> exceeds the row count of the matrix.
/// </remarks>
/// <param name="clContext">OpenCL context.</param>
/// <param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
let subRows (clContext: ClContext) workGroupSize =

    let upperBound = ClArray.upperBound clContext workGroupSize

    // One blit instance for the index arrays, another one for the values.
    let blit = ClArray.blit clContext workGroupSize

    let blitData = ClArray.blit clContext workGroupSize

    fun (processor: MailboxProcessor<_>) allocationMode startRow count (matrix: ClMatrix.COO<'a>) ->
        if count <= 0 then
            failwith "Count must be greater than zero"

        if startRow < 0 then
            failwith "startIndex must be greater then zero"

        if startRow + count > matrix.RowCount then
            failwith "startIndex and count sum is larger than the matrix row count"

        let firstRowClCell = clContext.CreateClCell(startRow - 1)
        let lastRowClCell = clContext.CreateClCell(startRow + count)

        // Locate the slice of the element arrays that belongs to the requested rows.
        // NOTE(review): this relies on matrix.Rows being sorted and on the exact
        // semantics of ClArray.upperBound — confirm the upper boundary is not off by one row.
        let firstIndex = (upperBound processor matrix.Rows firstRowClCell).ToHostAndFree processor
        let lastIndex = (upperBound processor matrix.Rows lastRowClCell).ToHostAndFree processor - 1

        firstRowClCell.Free processor
        lastRowClCell.Free processor

        let resultLength = lastIndex - firstIndex + 1

        // extract rows
        let rows =
            clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, resultLength)

        // Fix: the row indices have to come from matrix.Rows (they were taken from matrix.Columns).
        blit processor matrix.Rows firstIndex rows 0 resultLength

        // extract values
        let values =
            clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, resultLength)

        blitData processor matrix.Values firstIndex values 0 resultLength

        // extract columns
        let columns =
            clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, resultLength)

        blit processor matrix.Columns firstIndex columns 0 resultLength

        { Context = clContext
          RowCount = matrix.RowCount
          ColumnCount = matrix.ColumnCount
          Rows = rows
          Columns = columns
          Values = values }

src/GraphBLAS-sharp.Backend/Operations/Operations.fs

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,88 @@ module Operations =
9898
|> Some
9999
| _ -> failwith "Vector formats are not matching."
100100

101+
/// <summary>
/// Applies the given function pairwise to the corresponding elements of the two given vectors.
/// The result is written into the left vector.
/// </summary>
/// <remarks>
/// The two input vectors must have the same lengths.
/// Only the dense/dense combination is handled; any other combination fails.
/// </remarks>
/// <param name="map">The function to transform the pairs of the input elements.</param>
/// <param name="clContext">OpenCL context.</param>
/// <param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
let map2InPlace (map: Expr<'a option -> 'b option -> 'a option>) (clContext: ClContext) workGroupSize =
    let denseInPlace =
        Dense.Vector.map2InPlace map clContext workGroupSize

    fun (processor: MailboxProcessor<_>) (leftVector: ClVector<'a>) (rightVector: ClVector<'b>) ->
        match leftVector, rightVector with
        | ClVector.Dense target, ClVector.Dense operand ->
            // The left vector is passed twice: as the first operand and as the destination.
            denseInPlace processor target operand target
        | _ -> failwith "Unsupported vector format"
119+
120+
/// <summary>
/// Applies the given function pairwise to the corresponding elements of the two given vectors.
/// The result is written into the given result vector.
/// </summary>
/// <remarks>
/// The two input vectors must have the same lengths.
/// All three vectors have to be dense; any other combination fails.
/// </remarks>
/// <param name="map">The function to transform the pairs of the input elements.</param>
/// <param name="clContext">OpenCL context.</param>
/// <param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
let map2To (map: Expr<'a option -> 'b option -> 'c option>) (clContext: ClContext) workGroupSize =
    let denseTo =
        Dense.Vector.map2InPlace map clContext workGroupSize

    fun (processor: MailboxProcessor<_>) (leftVector: ClVector<'a>) (rightVector: ClVector<'b>) (resultVector: ClVector<'c>) ->
        match leftVector, rightVector, resultVector with
        | ClVector.Dense first, ClVector.Dense second, ClVector.Dense destination ->
            denseTo processor first second destination
        | _ -> failwith "Unsupported vector format"
138+
139+
/// <summary>
/// Applies the given function pairwise to the corresponding elements of the two given dense vectors.
/// Returns a new vector.
/// </summary>
/// <remarks>
/// The two input vectors must have the same lengths.
/// Both vectors have to be dense; any other combination fails.
/// </remarks>
/// <param name="map">The function to transform the pairs of the input elements.</param>
/// <param name="clContext">OpenCL context.</param>
/// <param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
let map2Dense (map: Expr<'a option -> 'b option -> 'a option>) (clContext: ClContext) workGroupSize =
    let denseMap2 =
        Dense.Vector.map2 map clContext workGroupSize

    fun (processor: MailboxProcessor<_>) allocationFlag (leftVector: ClVector<'a>) (rightVector: ClVector<'b>) ->
        match leftVector, rightVector with
        | ClVector.Dense first, ClVector.Dense second ->
            denseMap2 processor allocationFlag first second
        | _ -> failwith "Unsupported vector format"
157+
158+
/// <summary>
/// Applies the given function pairwise to the corresponding elements of the two given vectors,
/// where the left vector is sparse and the right one is sparse or dense.
/// Returns a new sparse vector as an option.
/// </summary>
/// <remarks>
/// The two input vectors must have the same lengths.
/// Any combination with a non-sparse left vector fails.
/// </remarks>
/// <param name="map">The function to transform the pairs of the input elements.</param>
/// <param name="clContext">OpenCL context.</param>
/// <param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
let map2Sparse (map: Expr<'a option -> 'b option -> 'a option>) (clContext: ClContext) workGroupSize =
    let sparseWithSparse =
        Sparse.Map2.run map clContext workGroupSize

    let sparseWithDense =
        Sparse.Map2.runSparseDense map clContext workGroupSize

    fun (processor: MailboxProcessor<_>) allocationFlag (leftVector: ClVector<'a>) (rightVector: ClVector<'b>) ->
        match leftVector, rightVector with
        | ClVector.Sparse first, ClVector.Sparse second ->
            sparseWithSparse processor allocationFlag first second
            |> Option.map ClVector.Sparse
        | ClVector.Sparse first, ClVector.Dense second ->
            sparseWithDense processor allocationFlag first second
            |> Option.map ClVector.Sparse
        | _ -> failwith "Unsupported vector format"
182+
101183
module Matrix =
102184
/// <summary>
103185
/// Builds a new matrix whose elements are the results of applying the given function
@@ -374,3 +456,43 @@ module Operations =
374456

375457
run processor allocationMode resultCapacity leftMatrix rightMatrix
376458
| _ -> failwith "Matrix formats are not matching"
459+
460+
module COO =
    /// <summary>
    /// Generalized matrix-matrix multiplication.
    /// The left matrix should be in COO format and the right one in CSR format;
    /// any other combination fails.
    /// </summary>
    /// <param name="opAdd">Type of binary function to reduce entries.</param>
    /// <param name="opMul">Type of binary function to combine entries.</param>
    /// <param name="clContext">OpenCL context.</param>
    /// <param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
    let expand
        (opAdd: Expr<'c -> 'c -> 'c option>)
        (opMul: Expr<'a -> 'b -> 'c option>)
        (clContext: ClContext)
        workGroupSize
        =

        let multiply =
            SpGeMM.Expand.COO.run opAdd opMul clContext workGroupSize

        fun (processor: MailboxProcessor<_>) allocationMode (leftMatrix: ClMatrix<'a>) (rightMatrix: ClMatrix<'b>) ->
            match leftMatrix, rightMatrix with
            | ClMatrix.COO left, ClMatrix.CSR right ->
                // Size in bytes of the widest element type among both inputs and the result.
                let widestElement =
                    [ sizeof<'a>
                      sizeof<'c>
                      sizeof<'b> ]
                    |> List.max

                let bytesPerElement = 1UL<Byte> * uint64 widestElement

                // Number of such elements that fit into a third of the largest single allocation.
                // NOTE(review): the reason for the factor 3 is not visible here — confirm.
                let elementsThatFit =
                    clContext.MaxMemAllocSize / bytesPerElement / 3UL

                // Clamp to the int range before the conversion.
                let resultCapacity =
                    int (min (uint64 System.Int32.MaxValue) elementsThatFit)

                multiply processor allocationMode resultCapacity left right
            | _ -> failwith "Matrix formats are not matching"

0 commit comments

Comments
 (0)