Skip to content

Commit c89e1b9

Browse files
committed
add: CSR.map2
1 parent 7349c1d commit c89e1b9

7 files changed

Lines changed: 182 additions & 40 deletions

File tree

src/GraphBLAS-sharp.Backend/GraphBLAS-sharp.Backend.fsproj

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
<Compile Include="Quotes/PreparePositions.fs" />
2727
<Compile Include="Quotes/Predicates.fs" />
2828
<Compile Include="Quotes/Map.fs" />
29-
<Compile Include="Quotes\BinSearch.fs" />
29+
<Compile Include="Quotes/BinSearch.fs" />
3030
<Compile Include="Common/Scatter.fs" />
3131
<Compile Include="Common/Utils.fs" />
3232
<Compile Include="Common/Sum.fs" />
@@ -40,6 +40,7 @@
4040
<Compile Include="Matrix/COOMatrix/Map2AtLeastOne.fs" />
4141
<Compile Include="Matrix/COOMatrix/Map.fs" />
4242
<Compile Include="Matrix/COOMatrix/Matrix.fs" />
43+
<Compile Include="Matrix/CSRMatrix/Map2.fs" />
4344
<Compile Include="Matrix/CSRMatrix/Map2AtLeastOne.fs" />
4445
<Compile Include="Matrix/CSRMatrix/SpGEMM.fs" />
4546
<Compile Include="Matrix/CSRMatrix/Map.fs" />

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/Map.fs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ open GraphBLAS.FSharp.Backend.Objects.ClMatrix
1111
open GraphBLAS.FSharp.Backend.Objects.ClContext
1212

1313
module internal Map =
14-
let preparePositions<'a, 'b> (clContext: ClContext) workGroupSize opAdd =
14+
let preparePositions<'a, 'b> (clContext: ClContext) workGroupSize op =
1515

1616
let preparePositions (op: Expr<'a option -> 'b option>) =
1717
<@ fun (ndRange: Range1D) rowCount columnCount (values: ClArray<'a>) (rowPointers: ClArray<int>) (columns: ClArray<int>) (resultBitmap: ClArray<int>) (resultValues: ClArray<'b>) (resultRows: ClArray<int>) (resultColumns: ClArray<int>) ->
@@ -38,8 +38,7 @@ module internal Map =
3838
resultBitmap.[gid] <- 1
3939
| None -> resultBitmap.[gid] <- 0 @>
4040

41-
let kernel =
42-
clContext.Compile <| preparePositions opAdd
41+
let kernel = clContext.Compile <| preparePositions op
4342

4443
fun (processor: MailboxProcessor<_>) rowCount columnCount (values: ClArray<'a>) (rowPointers: ClArray<int>) (columns: ClArray<int>) ->
4544

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
namespace GraphBLAS.FSharp.Backend.Matrix.CSR
2+
3+
open Brahma.FSharp
4+
open Microsoft.FSharp.Quotations
5+
open GraphBLAS.FSharp.Backend
6+
open GraphBLAS.FSharp.Backend.Matrix
7+
open GraphBLAS.FSharp.Backend.Quotes
8+
open GraphBLAS.FSharp.Backend.Objects
9+
open GraphBLAS.FSharp.Backend.Objects.ClMatrix
10+
open GraphBLAS.FSharp.Backend.Objects.ClContext
11+
open GraphBLAS.FSharp.Backend.Matrix.COO
12+
13+
module internal Map2 =
14+
let preparePositions<'a, 'b, 'c> (clContext: ClContext) workGroupSize opAdd =
15+
16+
let preparePositions (op: Expr<'a option -> 'b option -> 'c option>) =
17+
<@ fun (ndRange: Range1D) rowCount columnCount (leftValues: ClArray<'a>) (leftRowPointers: ClArray<int>) (leftColumns: ClArray<int>) (rightValues: ClArray<'b>) (rightRowPointers: ClArray<int>) (rightColumn: ClArray<int>) (resultBitmap: ClArray<int>) (resultValues: ClArray<'c>) (resultRows: ClArray<int>) (resultColumns: ClArray<int>) ->
18+
19+
let gid = ndRange.GlobalID0
20+
21+
if gid < rowCount * columnCount then
22+
23+
let columnIndex = gid % columnCount
24+
let rowIndex = gid / columnCount
25+
26+
let leftStartIndex = leftRowPointers.[rowIndex]
27+
let leftLastIndex = leftRowPointers.[rowIndex + 1] - 1
28+
29+
let rightStartIndex = rightRowPointers.[rowIndex]
30+
let rightLastIndex = rightRowPointers.[rowIndex + 1] - 1
31+
32+
let leftValue =
33+
(%BinSearch.searchInRange) leftStartIndex leftLastIndex columnIndex leftColumns leftValues
34+
35+
let rightValue =
36+
(%BinSearch.searchInRange) rightStartIndex rightLastIndex columnIndex rightColumn rightValues
37+
38+
match (%op) leftValue rightValue with
39+
| Some value ->
40+
resultValues.[gid] <- value
41+
resultRows.[gid] <- rowIndex
42+
resultColumns.[gid] <- columnIndex
43+
44+
resultBitmap.[gid] <- 1
45+
| None -> resultBitmap.[gid] <- 0 @>
46+
47+
let kernel =
48+
clContext.Compile <| preparePositions opAdd
49+
50+
fun (processor: MailboxProcessor<_>) rowCount columnCount (leftValues: ClArray<'a>) (leftRows: ClArray<int>) (leftColumns: ClArray<int>) (rightValues: ClArray<'b>) (rightRows: ClArray<int>) (rightColumns: ClArray<int>) ->
51+
52+
let (resultLength: int) = columnCount * rowCount
53+
54+
let resultBitmap =
55+
clContext.CreateClArrayWithSpecificAllocationMode<int>(DeviceOnly, resultLength)
56+
57+
let resultRows =
58+
clContext.CreateClArrayWithSpecificAllocationMode<int>(DeviceOnly, resultLength)
59+
60+
let resultColumns =
61+
clContext.CreateClArrayWithSpecificAllocationMode<int>(DeviceOnly, resultLength)
62+
63+
let resultValues =
64+
clContext.CreateClArrayWithSpecificAllocationMode<'c>(DeviceOnly, resultLength)
65+
66+
let ndRange =
67+
Range1D.CreateValid(resultLength, workGroupSize)
68+
69+
let kernel = kernel.GetKernel()
70+
71+
processor.Post(
72+
Msg.MsgSetArguments
73+
(fun () ->
74+
kernel.KernelFunc
75+
ndRange
76+
rowCount
77+
columnCount
78+
leftValues
79+
leftRows
80+
leftColumns
81+
rightValues
82+
rightRows
83+
rightColumns
84+
resultBitmap
85+
resultValues
86+
resultRows
87+
resultColumns)
88+
)
89+
90+
processor.Post(Msg.CreateRunMsg<_, _> kernel)
91+
92+
resultBitmap, resultValues, resultRows, resultColumns
93+
94+
///<param name="clContext">.</param>
95+
///<param name="opAdd">.</param>
96+
///<param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
97+
let runToCOO<'a, 'b, 'c when 'a: struct and 'b: struct and 'c: struct and 'c: equality>
98+
(clContext: ClContext)
99+
(opAdd: Expr<'a option -> 'b option -> 'c option>)
100+
workGroupSize
101+
=
102+
103+
let map2 =
104+
preparePositions clContext workGroupSize opAdd
105+
106+
let setPositions =
107+
Common.setPositions<'c> clContext workGroupSize
108+
109+
fun (queue: MailboxProcessor<_>) allocationMode (matrixLeft: ClMatrix.CSR<'a>) (matrixRight: ClMatrix.CSR<'b>) ->
110+
111+
let bitmap, values, rows, columns =
112+
map2
113+
queue
114+
matrixLeft.RowCount
115+
matrixLeft.ColumnCount
116+
matrixLeft.Values
117+
matrixLeft.RowPointers
118+
matrixLeft.Columns
119+
matrixRight.Values
120+
matrixRight.RowPointers
121+
matrixRight.Columns
122+
123+
let resultRows, resultColumns, resultValues, _ =
124+
setPositions queue allocationMode rows columns values bitmap
125+
126+
queue.Post(Msg.CreateFreeMsg<_>(bitmap))
127+
queue.Post(Msg.CreateFreeMsg<_>(values))
128+
queue.Post(Msg.CreateFreeMsg<_>(rows))
129+
queue.Post(Msg.CreateFreeMsg<_>(columns))
130+
131+
{ Context = clContext
132+
RowCount = matrixLeft.RowCount
133+
ColumnCount = matrixLeft.ColumnCount
134+
Rows = resultRows
135+
Columns = resultColumns
136+
Values = resultValues }
137+
138+
let run<'a, 'b, 'c when 'a: struct and 'b: struct and 'c: struct and 'c: equality>
139+
(clContext: ClContext)
140+
(opAdd: Expr<'a option -> 'b option -> 'c option>)
141+
workGroupSize
142+
=
143+
144+
let map2ToCOO = runToCOO clContext opAdd workGroupSize
145+
146+
let toCSRInplace =
147+
Matrix.toCSRInplace clContext workGroupSize
148+
149+
fun (queue: MailboxProcessor<_>) allocationMode (matrixLeft: ClMatrix.CSR<'a>) (matrixRight: ClMatrix.CSR<'b>) ->
150+
map2ToCOO queue allocationMode matrixLeft matrixRight
151+
|> toCSRInplace queue allocationMode

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/Matrix.fs

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -92,31 +92,7 @@ module Matrix =
9292

9393
let map = CSR.Map.run
9494

95-
let map2<'a, 'b, 'c when 'a: struct and 'b: struct and 'c: struct and 'c: equality>
96-
(clContext: ClContext)
97-
(opAdd: Expr<'a option -> 'b option -> 'c option>)
98-
workGroupSize
99-
=
100-
101-
let firstToCOO = toCOO clContext workGroupSize
102-
103-
let secondToCOO = toCOO clContext workGroupSize
104-
105-
let COOMap2 =
106-
COO.Matrix.map2 clContext opAdd workGroupSize
107-
108-
let toCSR =
109-
COO.Matrix.toCSRInplace clContext workGroupSize
110-
111-
fun (processor: MailboxProcessor<_>) allocationMode (leftMatrix: ClMatrix.CSR<'a>) (rightMatrix: ClMatrix.CSR<'b>) ->
112-
let leftCOOMatrix =
113-
firstToCOO processor DeviceOnly leftMatrix
114-
115-
let rightCOOMatrix =
116-
secondToCOO processor DeviceOnly rightMatrix
117-
118-
COOMap2 processor DeviceOnly leftCOOMatrix rightCOOMatrix
119-
|> toCSR processor allocationMode
95+
let map2 = Map2.run
12096

12197
let map2AtLeastOneToCOO<'a, 'b, 'c when 'a: struct and 'b: struct and 'c: struct and 'c: equality>
12298
(clContext: ClContext)

src/GraphBLAS-sharp.Backend/Quotes/Arithmetic.fs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
open GraphBLAS.FSharp.Backend.Objects
44

55
module ArithmeticOperations =
6-
let inline mkOpWithConst zero op constant =
6+
let inline mkUnaryOp zero unaryOp =
77
<@ fun x ->
88
let mutable res = zero
99

1010
match x with
11-
| Some v -> res <- (op v constant)
12-
| None -> res <- constant
11+
| Some v -> res <- (%unaryOp) v
12+
| None -> res <- (%unaryOp) zero
1313

1414
if res = zero then None else Some res @>
1515

src/GraphBLAS-sharp.Backend/Quotes/BinSearch.fs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,15 @@
33
open Brahma.FSharp
44

55
module BinSearch =
6+
/// <summary>
7+
/// Searches a section of the array of indices, bounded by the given left and right edges, for an index, using a binary search algorithm.
8+
/// In case searched section contains source index, the value at the same position in the array of values is returned.
9+
/// </summary>
10+
/// <remarks>
11+
/// Searched section of index array should be sorted in ascending order.
12+
/// The index array should have the same length as the array of values.
13+
/// left edge and right edge should be less than the length of the index array.
14+
/// </remarks>
615
let searchInRange<'a> =
716
<@ fun leftEdge rightEdge sourceIndex (indices: ClArray<int>) (values: ClArray<'a>) ->
817

@@ -27,6 +36,13 @@ module BinSearch =
2736

2837
result @>
2938

39+
/// <summary>
40+
/// Searches matrix in COO format for a value, using a binary search algorithm.
41+
/// In case there is a value at the given position, it is returned.
42+
/// </summary>
43+
/// <remarks>
44+
/// Position is uint64 and it should be written in such format: first 32 bits is row, second 32 bits is column.
45+
/// </remarks>
3046
let searchCOO<'a> =
3147
<@ fun lenght sourceIndex (rowIndices: ClArray<int>) (columnIndices: ClArray<int>) (values: ClArray<'a>) ->
3248

tests/GraphBLAS-sharp.Tests/Matrix/Map.fs

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,7 @@ let createTestMap case (zero: 'a) op isEqual opQ map =
102102
|> testPropertyWithConfig config (getCorrectnessTestName $"{typeof<'a>}")
103103

104104
let testFixturesMapNot case =
105-
[ let context = case.TestContext.ClContext
106-
let q = case.TestContext.Queue
105+
[ let q = case.TestContext.Queue
107106
q.Error.Add(fun e -> failwithf "%A" e)
108107

109108
createTestMap case false not (=) ArithmeticOperations.notQ Matrix.map ]
@@ -117,13 +116,13 @@ let testFixturesMapAdd case =
117116
q.Error.Add(fun e -> failwithf "%A" e)
118117

119118
let addFloat64Q =
120-
ArithmeticOperations.mkOpWithConst 0.0 (+) 10.0
119+
ArithmeticOperations.mkUnaryOp 0.0 <@ fun x -> x + 10.0 @>
121120

122121
let addFloat32Q =
123-
ArithmeticOperations.mkOpWithConst 0.0f (+) 10.0f
122+
ArithmeticOperations.mkUnaryOp 0.0f <@ fun x -> x + 10.0f @>
124123

125124
let addByte =
126-
ArithmeticOperations.mkOpWithConst 0uy (+) 10uy
125+
ArithmeticOperations.mkUnaryOp 0uy <@ fun x -> x + 10uy @>
127126

128127
if Utils.isFloat64Available context.ClDevice then
129128
createTestMap case 0.0 ((+) 10.0) Utils.floatIsEqual addFloat64Q Matrix.map
@@ -140,13 +139,13 @@ let testFixturesMapMul case =
140139
q.Error.Add(fun e -> failwithf "%A" e)
141140

142141
let mulFloat64Q =
143-
ArithmeticOperations.mkOpWithConst 0.0 (*) 10.0
142+
ArithmeticOperations.mkUnaryOp 0.0 <@ fun x -> x * 10.0 @>
144143

145144
let mulFloat32Q =
146-
ArithmeticOperations.mkOpWithConst 0.0f (*) 10.0f
145+
ArithmeticOperations.mkUnaryOp 0.0f <@ fun x -> x * 10.0f @>
147146

148147
let mulByte =
149-
ArithmeticOperations.mkOpWithConst 0uy (*) 10uy
148+
ArithmeticOperations.mkUnaryOp 0uy <@ fun x -> x * 10uy @>
150149

151150
if Utils.isFloat64Available context.ClDevice then
152151
createTestMap case 0.0 ((*) 10.0) Utils.floatIsEqual mulFloat64Q Matrix.map

0 commit comments

Comments
 (0)