Skip to content

Commit be13d6e

Browse files
committed
add: Expand test
1 parent 5479816 commit be13d6e

10 files changed

Lines changed: 466 additions & 133 deletions

File tree

src/GraphBLAS-sharp.Backend/Common/Gather.fs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@ module internal Gather =
1717
let run (clContext: ClContext) workGroupSize =
1818

1919
let gather =
20-
<@ fun (ndRange: Range1D) (positions: ClArray<int>) (inputArray: ClArray<'a>) (outputArray: ClArray<'a>) (size: int) ->
20+
<@ fun (ndRange: Range1D) (positions: ClArray<int>) (values: ClArray<'a>) (outputArray: ClArray<'a>) (size: int) ->
2121

2222
let i = ndRange.GlobalID0
2323

2424
if i < size then
25-
outputArray.[i] <- inputArray.[positions.[i]] @>
25+
let position = positions.[i]
26+
let value = values.[position]
27+
28+
outputArray.[i] <- value @>
2629

2730
let program = clContext.Compile(gather)
2831

src/GraphBLAS-sharp.Backend/Common/Scatter.fs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,17 @@ module internal Scatter =
2727
let run =
2828
<@ fun (ndRange: Range1D) (positions: ClArray<int>) (positionsLength: int) (values: ClArray<'a>) (result: ClArray<'a>) (resultLength: int) ->
2929

30-
let i = ndRange.GlobalID0
30+
let gid = ndRange.GlobalID0
3131

32-
if i < positionsLength then
33-
let index = positions.[i]
32+
if gid < positionsLength then
33+
let index = positions.[gid]
3434

3535
if 0 <= index && index < resultLength then
36-
if i < positionsLength - 1 then
37-
if index <> positions.[i + 1] then
38-
result.[index] <- values.[i]
36+
if gid < positionsLength - 1 then
37+
if index <> positions.[gid + 1] then
38+
result.[index] <- values.[gid]
3939
else
40-
result.[index] <- values.[i] @>
40+
result.[index] <- values.[gid] @>
4141

4242
let program = clContext.Compile(run)
4343

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/SpGEMM/Expand.fs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,8 @@ module Expand =
230230
removeDuplications processor globalRightMatrixRowsStartPositions
231231

232232
// RESULT row pointers into result expanded (obtained by multiplication) array
233+
// printfn "GLOBAL LENGTH: %A" globalLength
234+
233235
let resultRowPointers =
234236
getRowPointers processor globalLength leftMatrix.RowPointers globalRightMatrixRowsStartPositions
235237

@@ -239,6 +241,8 @@ module Expand =
239241
let globalMap =
240242
getGlobalPositions processor globalLength globalRightMatrixRawsPointersWithoutDuplicates
241243

244+
// printfn "global clmap: %A" <| globalMap.ToHost processor
245+
242246
globalMap, globalRightMatrixRawsPointersWithoutDuplicates, requiredLeftMatrixValues, requiredRightMatrixRawPointers, resultRowPointers
243247

244248
let expandRightMatrixValuesIndices (clContext: ClContext) workGroupSize =
@@ -250,7 +254,7 @@ module Expand =
250254

251255
if gid < length then
252256
// index corresponding to the position of pointers
253-
let positionIndex = globalPositions.[gid] - 1
257+
let positionIndex = globalPositions.[gid] - 1 // TODO()
254258

255259
// the position of the beginning of a new line of pointers
256260
let sourcePosition = globalRightMatrixValuesPositions.[positionIndex]
@@ -303,7 +307,7 @@ module Expand =
303307

304308
// globalBitmap.Length == resultValues.Length
305309
if gid < resultLength then
306-
let valueIndex = globalBitmap.[gid] - 1
310+
let valueIndex = globalBitmap.[gid] - 1 //TODO()
307311

308312
resultValues.[gid] <- leftMatrixValues.[valueIndex] @>
309313

src/GraphBLAS-sharp/Objects/Matrix.fs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ module Matrix =
1919
sprintf "Values: %A \n" this.Values ]
2020
|> String.concat ""
2121

22+
member this.NNZ = this.Values.Length
23+
2224
static member FromTuples(rowCount: int, columnCount: int, rows: int [], columns: int [], values: 'a []) =
2325
{ RowCount = rowCount
2426
ColumnCount = columnCount
@@ -79,6 +81,8 @@ module Matrix =
7981
RowCount = rowsCount
8082
ColumnCount = columnsCount }
8183

84+
member this.NNZ = this.Values.Length
85+
8286
member this.ToDevice(context: ClContext) =
8387
{ Context = context
8488
RowCount = this.RowCount
@@ -121,6 +125,8 @@ module Matrix =
121125
RowCount = rowsCount
122126
ColumnCount = columnsCount }
123127

128+
member this.NNZ = this.Values.Length
129+
124130
member this.ToDevice(context: ClContext) =
125131
{ Context = context
126132
RowCount = this.RowCount
@@ -154,9 +160,9 @@ type Matrix<'a when 'a: struct> =
154160

155161
member this.NNZ =
156162
match this with
157-
| COO m -> m.Values.Length
158-
| CSR m -> m.Values.Length
159-
| CSC m -> m.Values.Length
163+
| COO m -> m.NNZ
164+
| CSR m -> m.NNZ
165+
| CSC m -> m.NNZ
160166

161167
member this.ToDevice(context: ClContext) =
162168
match this with

tests/GraphBLAS-sharp.Tests/Generators.fs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,14 @@ module Generators =
3838
/// Generates empty matrices as well.
3939
/// </remarks>
4040
let dimension2DGenerator =
41-
Gen.sized
42-
<| fun size -> Gen.choose (1, size) |> Gen.two
41+
fun size -> Gen.choose (1, size)
42+
|> Gen.sized
43+
|> Gen.two
4344

4445
let dimension3DGenerator =
45-
Gen.sized
46-
<| fun size -> Gen.choose (1, size) |> Gen.three
46+
fun size -> Gen.choose (1, size)
47+
|> Gen.sized
48+
|> Gen.three
4749

4850
let rec normalFloat32Generator (random: System.Random) =
4951
gen {
@@ -384,6 +386,9 @@ module Generators =
384386
valuesGenerator
385387
|> Gen.array2DOfDim (nColsA, nColsB)
386388

389+
printf $"left matrix column count: %A{Array2D.length1 matrixA}"
390+
printf $"right matrix row count: %A{Array2D.length2 matrixA}"
391+
387392
return (matrixA, matrixB)
388393
}
389394

tests/GraphBLAS-sharp.Tests/GraphBLAS-sharp.Tests.fsproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
<Compile Include="Matrix/Map2.fs" />
4646
<Compile Include="Matrix/Mxm.fs" />
4747
<Compile Include="Matrix/Transpose.fs" />
48+
<Compile Include="Matrix\SpGEMM\Example.fs" />
4849
<Compile Include="Matrix\SpGEMM\Expand.fs" />
4950
<Compile Include="Program.fs" />
5051
</ItemGroup>

tests/GraphBLAS-sharp.Tests/Helpers.fs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,63 @@ module Utils =
129129

130130
result
131131

132+
let prefixSumExclude (array: 'a []) zero plus =
133+
let mutable sum = zero
134+
135+
for i in 0 .. array.Length - 1 do
136+
let currentItem = array.[i]
137+
array.[i] <- sum
138+
139+
sum <- plus currentItem sum
140+
141+
sum
142+
143+
let prefixSumInclude (array: 'a []) zero plus =
144+
let mutable sum = zero
145+
146+
for i in 0 .. array.Length - 1 do
147+
sum <- plus array.[i] sum
148+
149+
array.[i] <- sum
150+
151+
sum
152+
153+
let getUniqueBitmap<'a when 'a: equality> (array: 'a []) =
154+
let bitmap = Array.zeroCreate array.Length
155+
156+
for i in 0 .. array.Length - 2 do
157+
if array.[i] <> array.[i + 1] then bitmap.[i] <- 1
158+
159+
// set last 1
160+
bitmap.[bitmap.Length - 1] <- 1
161+
162+
bitmap
163+
164+
let scatter (positions: int array) (values: 'a array) (resultValues: 'a array) =
165+
for i in 0 .. positions.Length - 2 do
166+
if positions.[i] <> positions.[i + 1] then
167+
let valuePosition = positions.[i]
168+
let value = values.[i]
169+
170+
resultValues.[valuePosition] <- value
171+
172+
// set last value
173+
let lastPosition = positions.[positions.Length - 1]
174+
let lastValue = values.[values.Length - 1]
175+
176+
resultValues.[lastPosition] <- lastValue
177+
178+
let gather (positions: int []) (values: 'a []) (result: 'a []) =
179+
for i in 0 .. positions.Length do
180+
let position = positions.[i]
181+
let value = values.[position]
182+
183+
result.[position] <- value
184+
185+
let castMatrixToCSR = function
186+
| Matrix.CSR matrix -> matrix
187+
| _ -> failwith "matrix format must be CSR"
188+
132189
module Context =
133190
type TestContext =
134191
{ ClContext: ClContext

0 commit comments

Comments
 (0)