Skip to content

Commit 751ee68

Browse files
committed
add: filter after multiplication
1 parent f34e590 commit 751ee68

7 files changed

Lines changed: 104 additions & 47 deletions

File tree

src/GraphBLAS-sharp.Backend/Common/ClArray.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ module ClArray =
389389

390390
result
391391

392-
let private assignOption2 (clContext: ClContext) workGroupSize (op: Expr<'a -> 'b -> 'c option>) =
392+
let assignOption2 (clContext: ClContext) workGroupSize (op: Expr<'a -> 'b -> 'c option>) =
393393

394394
let assign =
395395
<@ fun (ndRange: Range1D) length (firstValues: ClArray<'a>) (secondValues: ClArray<'b>) (positions: ClArray<int>) (result: ClArray<'c>) resultLength ->

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/SpGEMM/Expand.fs

Lines changed: 59 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ open GraphBLAS.FSharp.Backend.Objects.ClContext
99
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
1010
open GraphBLAS.FSharp.Backend.Objects
1111
open GraphBLAS.FSharp.Backend.Objects.ClCell
12+
open FSharp.Quotations
1213

1314
type Indices = ClArray<int>
1415

@@ -65,7 +66,40 @@ module Expand =
6566

6667
length, segmentsLengths
6768

68-
let expand (clContext: ClContext) workGroupSize opMul =
69+
let multiply (clContext: ClContext) workGroupSize (predicate: Expr<'a -> 'b -> 'c option>) =
70+
let getBitmap =
71+
ClArray.map2<'a, 'b, int> clContext workGroupSize
72+
<| Map.chooseBitmap2 predicate
73+
74+
let prefixSum = PrefixSum.standardExcludeInplace clContext workGroupSize
75+
76+
let assignValues = ClArray.assignOption2 clContext workGroupSize predicate
77+
78+
let scatter = Scatter.lastOccurrence clContext workGroupSize // TODO(last ?)
79+
80+
fun (processor: MailboxProcessor<_>) (firstValues: ClArray<'a>) (secondValues: ClArray<'b>) (columns: Indices) (rows: Indices) ->
81+
82+
let positions = getBitmap processor DeviceOnly firstValues secondValues
83+
84+
let resultLength =
85+
(prefixSum processor positions)
86+
.ToHostAndFree(processor)
87+
88+
let resultColumns = clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, resultLength)
89+
90+
scatter processor positions columns resultColumns
91+
92+
let resultRows = clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, resultLength)
93+
94+
scatter processor positions rows resultRows
95+
96+
let resultValues = clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, resultLength)
97+
98+
assignValues processor firstValues secondValues positions resultValues
99+
100+
resultValues, resultColumns, resultRows
101+
102+
let expand (clContext: ClContext) workGroupSize =
69103

70104
let idScatter = Scatter.initLastOccurrence Map.id clContext workGroupSize
71105

@@ -89,8 +123,6 @@ module Expand =
89123

90124
let BGather = Gather.run clContext workGroupSize
91125

92-
let mul = ClArray.map2 clContext workGroupSize opMul
93-
94126
fun (processor: MailboxProcessor<_>) lengths (segmentsPointers: Indices) (leftMatrix: ClMatrix.CSR<'a>) (rightMatrix: ClMatrix.CSR<'b>) ->
95127

96128
// Compute A positions
@@ -150,13 +182,8 @@ module Expand =
150182

151183
BPositions.Free processor
152184

153-
// multiply values TODO(filter values)
154-
let values = mul processor DeviceOnly AValues BValues
155-
156-
AValues.Free processor
157-
BValues.Free processor
158-
159-
values, columns, rows
185+
// left, right matrix values, columns and rows indices
186+
AValues, BValues, columns, rows
160187

161188
let sortByColumnsAndRows (clContext: ClContext) workGroupSize =
162189

@@ -227,7 +254,9 @@ module Expand =
227254

228255
let getSegmentPointers = getSegmentPointers clContext workGroupSize
229256

230-
let expand = expand clContext workGroupSize opMul
257+
let expand = expand clContext workGroupSize
258+
259+
let multiply = multiply clContext workGroupSize opMul
231260

232261
let sort = sortByColumnsAndRows clContext workGroupSize
233262

@@ -237,24 +266,37 @@ module Expand =
237266

238267
let length, segmentPointers = getSegmentPointers processor leftMatrix rightMatrix
239268

240-
let values, columns, rows =
269+
// expand
270+
let leftMatrixValues, rightMatrixValues, columns, rows =
241271
expand processor length segmentPointers leftMatrix rightMatrix
242272

243-
printfn $"expanded values: %A{values.ToHost processor}"
273+
printfn $"left matrix values: %A{leftMatrixValues.ToHost processor}"
274+
printfn $"right matrix values: %A{rightMatrixValues.ToHost processor}"
244275
printfn $"expanded columns: %A{columns.ToHost processor}"
245276
printfn $"expanded rows: %A{rows.ToHost processor}"
246277

278+
// multiply
279+
let resultValues, resultColumns, resultRows =
280+
multiply processor leftMatrixValues rightMatrixValues columns rows
281+
282+
leftMatrixValues.Free processor
283+
rightMatrixValues.Free processor
284+
columns.Free processor
285+
rows.Free processor
286+
287+
// sort
247288
let sortedValues, sortedColumns, sortedRows =
248-
sort processor values columns rows
289+
sort processor resultValues resultColumns resultRows
249290

250291
printfn $"sorted values: %A{sortedValues.ToHost processor}"
251292
printfn $"sorted columns: %A{sortedColumns.ToHost processor}"
252293
printfn $"sorted rows: %A{sortedRows.ToHost processor}"
253294

254-
values.Free processor
255-
columns.Free processor
256-
rows.Free processor
295+
resultValues.Free processor
296+
resultColumns.Free processor
297+
resultRows.Free processor
257298

299+
// addition
258300
let reducedValues, reducedColumns, reducedRows =
259301
reduce processor allocationMode sortedValues sortedColumns sortedRows
260302

src/GraphBLAS-sharp.Backend/Quotes/Arithmetic.fs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,20 @@ module ArithmeticOperations =
141141

142142
let inline createPair zero op opQ = binOpQ zero opQ, binOp zero op
143143

144+
// addition
144145
let intAdd = createPair 0 (+) <@ (+) @>
145146

146147
let boolAdd = createPair false (||) <@ (||) @>
147148

148149
let floatAdd = createPair 0.0 (+) <@ (+) @>
149150

150151
let float32Add = createPair 0.0f (+) <@ (+) @>
152+
153+
// multiplication
154+
let intMul = createPair 0 (*) <@ (*) @>
155+
156+
let boolMul = createPair false (&&) <@ (&&) @>
157+
158+
let floatMul = createPair 0.0 (*) <@ (*) @>
159+
160+
let float32Mul = createPair 0.0f (*) <@ (*) @>

tests/GraphBLAS-sharp.Tests/Common/ClArray/Choose.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ let createTest2 (isEqual: 'a -> 'a -> bool) (opMapQ, opMap) testFun =
8585
let testFun = testFun context Utils.defaultWorkGroupSize opMapQ
8686

8787
makeTest2 isEqual opMap testFun
88-
|> testPropertyWithConfig { config with maxTest = 1000 } $"test on %A{typeof<'a>}"
88+
|> testPropertyWithConfig config $"test on %A{typeof<'a>}"
8989

9090
let tests2 =
9191
[ createTest2 (=) ArithmeticOperations.intAdd ClArray.choose2

tests/GraphBLAS-sharp.Tests/Helpers.fs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ module HostPrimitives =
229229

230230
result
231231

232-
let array2DMultiplication mul add leftArray rightArray =
232+
let array2DMultiplication zero mul add leftArray rightArray =
233233
if Array2D.length2 leftArray <> Array2D.length1 rightArray then
234234
failwith "Incompatible matrices"
235235

@@ -239,7 +239,8 @@ module HostPrimitives =
239239
<| fun i j ->
240240
(leftArray.[i, *], rightArray.[*, j])
241241
||> Array.map2 mul
242-
|> Array.reduce add
242+
|> Array.choose id
243+
|> Array.fold add zero
243244

244245
module Context =
245246
type TestContext =

tests/GraphBLAS-sharp.Tests/Matrix/SpGeMM.fs

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module GraphBLAS.FSharp.Tests.Matrix.SpGeMM
22

33
open Expecto
44
open GraphBLAS.FSharp.Backend.Matrix.CSR.SpGeMM
5+
open GraphBLAS.FSharp.Backend.Quotes
56
open GraphBLAS.FSharp.Test
67
open Microsoft.FSharp.Collections
78
open GraphBLAS.FSharp.Backend
@@ -76,7 +77,7 @@ let getSegmentsTests =
7677
createTest ((=) 0uy) Expand.getSegmentPointers ]
7778
|> testList "get segment pointers"
7879

79-
let expand length segmentPointers mulOp (leftMatrix: Matrix.CSR<'a>) (rightMatrix: Matrix.CSR<'b>) =
80+
let expand length segmentPointers (leftMatrix: Matrix.CSR<'a>) (rightMatrix: Matrix.CSR<'b>) =
8081
let extendPointers pointers =
8182
Array.pairwise pointers
8283
|> Array.map (fun (fst, snd) -> snd - fst)
@@ -106,11 +107,9 @@ let expand length segmentPointers mulOp (leftMatrix: Matrix.CSR<'a>) (rightMatri
106107
|> Array.concat
107108
|> Array.unzip
108109

109-
let expectedValues = Array.map2 mulOp leftMatrixValues rightMatrixValues
110+
leftMatrixValues, rightMatrixValues, expectedColumns, expectedRows
110111

111-
expectedValues, expectedColumns, expectedRows
112-
113-
let makeExpandTest isEqual zero opMul testFun (leftArray: 'a [,], rightArray: 'a [,]) =
112+
let makeExpandTest isEqual zero testFun (leftArray: 'a [,], rightArray: 'a [,]) =
114113

115114
let leftMatrix = createCSRMatrix leftArray <| isEqual zero
116115

@@ -126,51 +125,55 @@ let makeExpandTest isEqual zero opMul testFun (leftArray: 'a [,], rightArray: 'a
126125
let clRightMatrix = rightMatrix.ToDevice context
127126
let clSegmentPointers = context.CreateClArray segmentPointers
128127

129-
let (clActualValues: ClArray<'a>), (clActualColumns: ClArray<int>), (clActualRows: ClArray<int>) =
128+
let (clActualLeftValues: ClArray<'a>), (clActualRightValues: ClArray<'a>), (clActualColumns: ClArray<int>), (clActualRows: ClArray<int>) =
130129
testFun processor length clSegmentPointers clLeftMatrix clRightMatrix
131130

132131
clLeftMatrix.Dispose processor
133132
clRightMatrix.Dispose processor
134133
clSegmentPointers.Free processor
135134

136-
let actualValues = clActualValues.ToHostAndFree processor
135+
let actualLeftValues = clActualLeftValues.ToHostAndFree processor
136+
let actualRightValues = clActualRightValues.ToHostAndFree processor
137137
let actualColumns = clActualColumns.ToHostAndFree processor
138138
let actualRows = clActualRows.ToHostAndFree processor
139139

140-
let expectedValues, expectedColumns, expectedRows =
141-
expand length segmentPointers opMul leftMatrix rightMatrix
140+
let expectedLeftMatrixValues, expectedRightMatrixValues, expectedColumns, expectedRows =
141+
expand length segmentPointers leftMatrix rightMatrix
142+
143+
"Left values must be the same"
144+
|> Utils.compareArrays isEqual actualLeftValues expectedLeftMatrixValues
142145

143-
"Values must be the same"
144-
|> Utils.compareArrays isEqual actualValues expectedValues
146+
"Right values must be the same"
147+
|> Utils.compareArrays isEqual actualRightValues expectedRightMatrixValues
145148

146149
"Columns must be the same"
147150
|> Utils.compareArrays (=) actualColumns expectedColumns
148151

149152
"Rows must be the same"
150153
|> Utils.compareArrays (=) actualRows expectedRows
151154

152-
let createExpandTest isEqual (zero: 'a) opMul opMulQ testFun =
155+
let createExpandTest isEqual (zero: 'a) testFun =
153156

154-
let testFun = testFun context Utils.defaultWorkGroupSize opMulQ
157+
let testFun = testFun context Utils.defaultWorkGroupSize
155158

156-
makeExpandTest isEqual zero opMul testFun
159+
makeExpandTest isEqual zero testFun
157160
|> testPropertyWithConfig config $"test on %A{typeof<'a>}"
158161

159162
let expandTests =
160-
[ createExpandTest (=) 0 (*) <@ (*) @> Expand.expand
163+
[ createExpandTest (=) 0 Expand.expand
161164

162165
if Utils.isFloat64Available context.ClDevice then
163-
createExpandTest Utils.floatIsEqual 0.0 (*) <@ (*) @> Expand.expand
166+
createExpandTest Utils.floatIsEqual 0.0 Expand.expand
164167

165-
createExpandTest Utils.float32IsEqual 0f (*) <@ (*) @> Expand.expand
166-
createExpandTest (=) false (&&) <@ (&&) @> Expand.expand
167-
createExpandTest (=) 0uy (*) <@ (*) @> Expand.expand ]
168+
createExpandTest Utils.float32IsEqual 0f Expand.expand
169+
createExpandTest (=) false Expand.expand
170+
createExpandTest (=) 0uy Expand.expand ]
168171
|> testList "Expand.expand"
169172

170173
let checkGeneralResult zero isEqual actualValues actualColumns actualRows mul add (leftArray: 'a [,]) (rightArray: 'a [,]) =
171174

172175
let expected =
173-
HostPrimitives.array2DMultiplication mul add leftArray rightArray
176+
HostPrimitives.array2DMultiplication zero mul add leftArray rightArray
174177
|> fun array -> Utils.createMatrixFromArray2D COO array (isEqual zero)
175178
|> function Matrix.COO matrix -> matrix | _ -> failwith "format miss"
176179

@@ -217,15 +220,15 @@ let makeGeneralTest zero isEqual opMul opAdd testFun (leftArray: 'a [,], rightAr
217220
checkGeneralResult zero isEqual actualValues actualColumns actualRows opMul opAdd leftArray rightArray
218221
with
219222
| ex when ex.Message = "InvalidBufferSize" -> ()
220-
| ex -> raise ex
223+
| _ -> reraise ()
221224

222-
let createGeneralTest (zero: 'a) isEqual opAdd opAddQ opMul opMulQ testFun =
225+
let createGeneralTest (zero: 'a) isEqual opAddQ opAdd (opMulQ, opMul) testFun =
223226

224227
let testFun = testFun context Utils.defaultWorkGroupSize opAddQ opMulQ
225228

226229
makeGeneralTest zero isEqual opMul opAdd testFun
227230
|> testPropertyWithConfig { config with endSize = 10; maxTest = 1000 } $"test on %A{typeof<'a>}"
228231

229232
let generalTests =
230-
[ createGeneralTest 0 (=) (+) <@ (+) @> (*) <@ (*) @> Expand.run ]
233+
[ createGeneralTest 0 (=) <@ (+) @> (+) ArithmeticOperations.intMul Expand.run ]
231234
|> testList "general"

tests/GraphBLAS-sharp.Tests/Program.fs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,11 @@ open GraphBLAS.FSharp.Tests.Matrix
9494
let allTests =
9595
testList
9696
"All tests"
97-
[ // SpGeMM.generalTests
97+
[ // SpGeMM.expandTests
98+
SpGeMM.generalTests
9899
// Common.Gather.initTests
99-
Common.ClArray.Choose.tests2 ]
100-
100+
// Common.ClArray.Choose.tests2 ]
101+
]
101102
|> testSequenced
102103

103104
[<EntryPoint>]

0 commit comments

Comments
 (0)