Skip to content

Commit 3d2cc7f

Browse files
committed
add: CSR.subByRows
1 parent 1ebf84b commit 3d2cc7f

9 files changed

Lines changed: 325 additions & 51 deletions

File tree

benchmarks/GraphBLAS-sharp.Benchmarks/Matrix/SpGeMM/Expand.fs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -134,14 +134,14 @@ module WithoutTransfer =
134134
override this.GlobalCleanup () =
135135
this.ClearInputMatrices()
136136

137-
type Float32() =
138-
139-
inherit Benchmark<float32>(
140-
Matrix.SpGeMM.expand (fst ArithmeticOperations.float32Add) (fst ArithmeticOperations.float32Mul),
141-
float32,
142-
(fun _ -> Utils.nextSingle (System.Random())),
143-
(fun context matrix -> ClMatrix.CSR <| matrix.ToCSR.ToDevice context)
144-
)
145-
146-
static member InputMatrixProvider =
147-
Benchmarks<_>.InputMatrixProviderBuilder "SpGeMM.txt"
137+
// type Float32() =
138+
//
139+
// inherit Benchmark<float32>(
140+
// Matrix.SpGeMM.expand (fst ArithmeticOperations.float32Add) (fst ArithmeticOperations.float32Mul),
141+
// float32,
142+
// (fun _ -> Utils.nextSingle (System.Random())),
143+
// (fun context matrix -> ClMatrix.CSR <| matrix.ToCSR.ToDevice context)
144+
// )
145+
//
146+
// static member InputMatrixProvider =
147+
// Benchmarks<_>.InputMatrixProviderBuilder "SpGeMM.txt"

src/GraphBLAS-sharp.Backend/Matrix/CSR/Matrix.fs

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,84 @@ module Matrix =
5151

5252
rows
5353

54+
let subRows (clContext: ClContext) workGroupSize =
55+
56+
let kernel =
57+
<@ fun (ndRange: Range1D) resultLength sourceRow pointersLength (pointers: ClArray<int>) (results: ClArray<int>) ->
58+
59+
let gid = ndRange.GlobalID0
60+
61+
let shift = pointers.[sourceRow]
62+
let shiftedId = gid + shift
63+
64+
if gid < resultLength then
65+
let result =
66+
(%Search.Bin.lowerBound 0) pointersLength shiftedId pointers
67+
68+
results.[gid] <- result - 1 @>
69+
70+
let program = clContext.Compile kernel
71+
72+
let blit = ClArray.blit clContext workGroupSize
73+
74+
let blitData = ClArray.blit clContext workGroupSize
75+
76+
fun (processor: MailboxProcessor<_>) allocationMode startIndex count (matrix: ClMatrix.CSR<'a>) ->
77+
if count <= 0 then
78+
failwith "Count must be greater than zero"
79+
80+
if startIndex < 0 then
81+
failwith "startIndex must be greater then zero"
82+
83+
if startIndex + count > matrix.RowCount then
84+
failwith "startIndex and count sum is larger than the matrix row count"
85+
86+
// extract rows
87+
let rowPointers = matrix.RowPointers.ToHost processor
88+
89+
let resultLength = rowPointers.[startIndex + count] - rowPointers.[startIndex]
90+
91+
let rows =
92+
clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, resultLength)
93+
94+
let kernel = program.GetKernel()
95+
96+
let ndRange =
97+
Range1D.CreateValid(matrix.Columns.Length, workGroupSize)
98+
99+
processor.Post(Msg.MsgSetArguments(
100+
fun () ->
101+
kernel.KernelFunc
102+
ndRange
103+
resultLength
104+
startIndex
105+
matrix.RowPointers.Length
106+
matrix.RowPointers
107+
rows))
108+
109+
processor.Post(Msg.CreateRunMsg<_, _> kernel)
110+
111+
let startPosition = rowPointers.[startIndex]
112+
113+
// extract values
114+
let values =
115+
clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, resultLength)
116+
117+
blitData processor matrix.Values startPosition values 0 resultLength
118+
119+
// extract indices
120+
let columns =
121+
clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, resultLength)
122+
123+
blit processor matrix.Columns startPosition columns 0 resultLength
124+
125+
{ Context = clContext
126+
RowCount = matrix.RowCount
127+
ColumnCount = matrix.ColumnCount
128+
Rows = rows
129+
Columns = columns
130+
Values = values }
131+
54132
let toCOO (clContext: ClContext) workGroupSize =
55133
let prepare = expandRowPointers clContext workGroupSize
56134

src/GraphBLAS-sharp/Objects/MatrixExtensions.fs

Lines changed: 95 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,42 +7,105 @@ open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
77
open GraphBLAS.FSharp.Objects.ClVectorExtensions
88

99
module MatrixExtensions =
10+
// Matrix.Free
11+
type ClMatrix.COO<'a when 'a : struct> with
12+
member this.Free(q: MailboxProcessor<_>) =
13+
this.Columns.Free q
14+
this.Values.Free q
15+
this.Rows.Free q
16+
17+
member this.ToHost(q: MailboxProcessor<_>) =
18+
{ RowCount = this.RowCount
19+
ColumnCount = this.ColumnCount
20+
Rows = this.Rows.ToHost q
21+
Columns = this.Columns.ToHost q
22+
Values = this.Values.ToHost q }
23+
24+
member this.ToHostAndFree(q: MailboxProcessor<_>) =
25+
let result = this.ToHost q
26+
this.Free q
27+
28+
result
29+
30+
type ClMatrix.CSR<'a when 'a : struct> with
31+
member this.Free(q: MailboxProcessor<_>) =
32+
this.Values.Free q
33+
this.Columns.Free q
34+
this.RowPointers.Free q
35+
36+
member this.ToHost(q: MailboxProcessor<_>) =
37+
{ RowCount = this.RowCount
38+
ColumnCount = this.ColumnCount
39+
RowPointers = this.RowPointers.ToHost q
40+
ColumnIndices = this.Columns.ToHost q
41+
Values = this.Values.ToHost q }
42+
43+
member this.ToHostAndFree(q: MailboxProcessor<_>) =
44+
let result = this.ToHost q
45+
this.Free q
46+
47+
result
48+
49+
type ClMatrix.CSC<'a when 'a : struct> with
50+
member this.Free(q: MailboxProcessor<_>) =
51+
this.Values.Free q
52+
this.Rows.Free q
53+
this.ColumnPointers.Free q
54+
55+
member this.ToHost(q: MailboxProcessor<_>) =
56+
{ RowCount = this.RowCount
57+
ColumnCount = this.ColumnCount
58+
RowIndices = this.Rows.ToHost q
59+
ColumnPointers = this.ColumnPointers.ToHost q
60+
Values = this.Values.ToHost q }
61+
62+
member this.ToHostAndFree(q: MailboxProcessor<_>) =
63+
let result = this.ToHost q
64+
this.Free q
65+
66+
result
67+
68+
type ClMatrix.LIL<'a when 'a : struct> with
69+
member this.Free(q: MailboxProcessor<_>) =
70+
this.Rows
71+
|> List.iter (Option.iter (fun row -> row.Dispose q))
72+
73+
member this.ToHost(q: MailboxProcessor<_>) =
74+
{ RowCount = this.RowCount
75+
ColumnCount = this.ColumnCount
76+
Rows =
77+
this.Rows
78+
|> List.map (Option.map (fun row -> row.ToHost q))
79+
NNZ = this.NNZ }
80+
81+
member this.ToHostAndFree(q: MailboxProcessor<_>) =
82+
let result = this.ToHost q
83+
this.Free q
84+
85+
result
86+
1087
type ClMatrix<'a when 'a: struct> with
1188
member this.ToHost(q: MailboxProcessor<_>) =
1289
match this with
13-
| ClMatrix.COO m ->
14-
{ RowCount = m.RowCount
15-
ColumnCount = m.ColumnCount
16-
Rows = m.Rows.ToHost q
17-
Columns = m.Columns.ToHost q
18-
Values = m.Values.ToHost q }
19-
|> Matrix.COO
20-
| ClMatrix.CSR m ->
21-
{ RowCount = m.RowCount
22-
ColumnCount = m.ColumnCount
23-
RowPointers = m.RowPointers.ToHost q
24-
ColumnIndices = m.Columns.ToHost q
25-
Values = m.Values.ToHost q }
26-
|> Matrix.CSR
27-
| ClMatrix.CSC m ->
28-
{ RowCount = m.RowCount
29-
ColumnCount = m.ColumnCount
30-
RowIndices = m.Rows.ToHost q
31-
ColumnPointers = m.ColumnPointers.ToHost q
32-
Values = m.Values.ToHost q }
33-
|> Matrix.CSC
34-
| ClMatrix.LIL m ->
35-
{ RowCount = m.RowCount
36-
ColumnCount = m.ColumnCount
37-
Rows =
38-
m.Rows
39-
|> List.map (Option.map (fun row -> row.ToHost q))
40-
NNZ = m.NNZ }
41-
|> Matrix.LIL
42-
43-
member this.ToHostAndDispose(processor: MailboxProcessor<_>) =
90+
| ClMatrix.COO m -> m.ToHost q |> Matrix.COO
91+
| ClMatrix.CSR m -> m.ToHost q |> Matrix.CSR
92+
| ClMatrix.CSC m -> m.ToHost q |> Matrix.CSC
93+
| ClMatrix.LIL m -> m.ToHost q |> Matrix.LIL
94+
95+
member this.Free(q: MailboxProcessor<_>) =
96+
match this with
97+
| ClMatrix.COO m -> m.Free q
98+
| ClMatrix.CSR m -> m.Free q
99+
| ClMatrix.CSC m -> m.Free q
100+
| ClMatrix.LIL m -> m.Free q
101+
102+
member this.FreeAndWait(processor: MailboxProcessor<_>) =
103+
this.Free processor
104+
processor.PostAndReply(MsgNotifyMe)
105+
106+
member this.ToHostAndFree(processor: MailboxProcessor<_>) =
44107
let result = this.ToHost processor
45108

46-
this.Dispose processor
109+
this.Free processor
47110

48111
result
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
module GraphBLAS.FSharp.Tests.Backend.Matrix.SubRows
2+
3+
open Expecto
4+
open GraphBLAS.FSharp.Test
5+
open GraphBLAS.FSharp.Tests
6+
open GraphBLAS.FSharp.Backend
7+
open GraphBLAS.FSharp.Objects
8+
open GraphBLAS.FSharp.Backend.Matrix
9+
open GraphBLAS.FSharp.Backend.Objects.ClContext
10+
open GraphBLAS.FSharp.Backend.Objects
11+
open GraphBLAS.FSharp.Objects.MatrixExtensions
12+
open GraphBLAS.FSharp.Objects.Matrix
13+
14+
let context = Context.defaultContext.ClContext
15+
16+
let processor = Context.defaultContext.Queue
17+
18+
let config = { Utils.defaultConfig with arbitrary = [ typeof<Generators.Matrix.Sub> ] }
19+
20+
let makeTest isEqual zero testFun (array: 'a [,], sourceRow, count) =
21+
22+
let matrix = Matrix.CSR.FromArray2D(array, isEqual zero)
23+
24+
if matrix.NNZ > 0 then
25+
26+
let clMatrix = matrix.ToDevice context
27+
28+
let clActual: ClMatrix.COO<'a> = testFun processor HostInterop sourceRow count clMatrix
29+
30+
let actual = clActual.ToHostAndFree processor
31+
32+
let expected =
33+
array
34+
|> Array2D.mapi (fun rowIndex columnIndex value -> (value, rowIndex, columnIndex))
35+
|> fun array -> array.[sourceRow .. sourceRow + count - 1, *]
36+
|> Seq.cast<'a * int * int>
37+
|> Seq.filter (fun (value, _, _) -> (not <| isEqual zero value))
38+
|> Seq.toArray
39+
|> Array.unzip3
40+
|> fun (values, rows, columns) ->
41+
{ RowCount = Array2D.length1 array
42+
ColumnCount = Array2D.length2 array
43+
Rows = rows
44+
Columns = columns
45+
Values = values }
46+
47+
Utils.compareCOOMatrix isEqual actual expected
48+
49+
let createTest isEqual (zero: 'a) =
50+
CSR.Matrix.subRows context Utils.defaultWorkGroupSize
51+
|> makeTest isEqual zero
52+
|> testPropertyWithConfig config $"test on {typeof<'a>}"
53+
54+
let tests =
55+
[ createTest (=) 0
56+
57+
if Utils.isFloat64Available context.ClDevice then
58+
createTest (=) 0.0
59+
60+
createTest (=) 0.0f
61+
createTest (=) false ]
62+
|> testList "Blit"

tests/GraphBLAS-sharp.Tests/Generators.fs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1159,3 +1159,62 @@ module Generators =
11591159
static member BoolType() =
11601160
pairOfVectorsOfEqualSize <| Arb.generate<bool>
11611161
|> Arb.fromGen
1162+
1163+
module Matrix =
1164+
type Sub() =
1165+
static let arrayAndChunkPosition (valuesGenerator: Gen<'a>) =
1166+
gen {
1167+
let! rowsCount = Gen.sized <| fun size -> Gen.choose (2, size + 2)
1168+
let! columnsCount = Gen.sized <| fun size -> Gen.choose (1, size + 1)
1169+
1170+
let! array = Gen.array2DOfDim (rowsCount, columnsCount) valuesGenerator
1171+
1172+
let! startPosition = Gen.choose (0, rowsCount - 2)
1173+
let! count = Gen.choose (1, rowsCount - startPosition - 1)
1174+
1175+
return (array, startPosition, count)
1176+
}
1177+
1178+
static member IntType() =
1179+
arrayAndChunkPosition <| Arb.generate<int>
1180+
|> Arb.fromGen
1181+
1182+
static member FloatType() =
1183+
arrayAndChunkPosition
1184+
<| (Arb.Default.NormalFloat()
1185+
|> Arb.toGen
1186+
|> Gen.map float)
1187+
|> Arb.fromGen
1188+
1189+
static member Float32Type() =
1190+
arrayAndChunkPosition
1191+
<| (normalFloat32Generator <| System.Random())
1192+
|> Arb.fromGen
1193+
1194+
static member SByteType() =
1195+
arrayAndChunkPosition <| Arb.generate<sbyte>
1196+
|> Arb.fromGen
1197+
1198+
static member ByteType() =
1199+
arrayAndChunkPosition <| Arb.generate<byte>
1200+
|> Arb.fromGen
1201+
1202+
static member Int16Type() =
1203+
arrayAndChunkPosition <| Arb.generate<int16>
1204+
|> Arb.fromGen
1205+
1206+
static member UInt16Type() =
1207+
arrayAndChunkPosition <| Arb.generate<uint16>
1208+
|> Arb.fromGen
1209+
1210+
static member Int32Type() =
1211+
arrayAndChunkPosition <| Arb.generate<int32>
1212+
|> Arb.fromGen
1213+
1214+
static member UInt32Type() =
1215+
arrayAndChunkPosition <| Arb.generate<uint32>
1216+
|> Arb.fromGen
1217+
1218+
static member BoolType() =
1219+
arrayAndChunkPosition <| Arb.generate<bool>
1220+
|> Arb.fromGen

tests/GraphBLAS-sharp.Tests/GraphBLAS-sharp.Tests.fsproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
<Compile Include="Backend/Matrix/Transpose.fs" />
5050
<Compile Include="Backend/Matrix/Merge.fs" />
5151
<Compile Include="Backend\Matrix\ExpandRows.fs" />
52+
<Compile Include="Backend\Matrix\SubRows.fs" />
5253
<Compile Include="Backend/Vector/AssignByMask.fs" />
5354
<Compile Include="Backend/Vector/Convert.fs" />
5455
<Compile Include="Backend/Vector/Copy.fs" />

0 commit comments

Comments
 (0)