Skip to content

Commit 5571002

Browse files
committed
refactor: tests, binSearch
1 parent 146b17f commit 5571002

7 files changed

Lines changed: 94 additions & 131 deletions

File tree

src/GraphBLAS-sharp.Backend/GraphBLAS-sharp.Backend.fsproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
<Compile Include="Quotes/PreparePositions.fs" />
2727
<Compile Include="Quotes/Predicates.fs" />
2828
<Compile Include="Quotes/Map.fs" />
29+
<Compile Include="Quotes\BinSearch.fs" />
2930
<Compile Include="Common/Scatter.fs" />
3031
<Compile Include="Common/Utils.fs" />
3132
<Compile Include="Common/Sum.fs" />

src/GraphBLAS-sharp.Backend/Matrix/COOMatrix/Map.fs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,19 @@
22

33
open Brahma.FSharp
44
open GraphBLAS.FSharp.Backend.Matrix
5+
open GraphBLAS.FSharp.Backend.Quotes
56
open Microsoft.FSharp.Quotations
67
open GraphBLAS.FSharp.Backend.Objects
78
open GraphBLAS.FSharp.Backend
89
open GraphBLAS.FSharp.Backend.Objects.ClMatrix
910
open GraphBLAS.FSharp.Backend.Objects.ClContext
1011

1112

12-
module Map =
13+
module internal Map =
1314
let preparePositions<'a, 'b> (clContext: ClContext) workGroupSize opAdd =
1415

1516
let preparePositions (op: Expr<'a option -> 'b option>) =
16-
<@ fun (ndRange: Range1D) rowCount columnCount valuesLength (values: ClArray<'a>) (rowPointers: ClArray<int>) (columns: ClArray<int>) (resultBitmap: ClArray<int>) (resultValues: ClArray<'b>) (resultRows: ClArray<int>) (resultColumns: ClArray<int>) ->
17+
<@ fun (ndRange: Range1D) rowCount columnCount valuesLength (values: ClArray<'a>) (rows: ClArray<int>) (columns: ClArray<int>) (resultBitmap: ClArray<int>) (resultValues: ClArray<'b>) (resultRows: ClArray<int>) (resultColumns: ClArray<int>) ->
1718

1819
let gid = ndRange.GlobalID0
1920

@@ -26,7 +27,7 @@ module Map =
2627
(uint64 rowIndex <<< 32) ||| (uint64 columnIndex)
2728

2829
let value =
29-
(%Map2.binSearch) valuesLength index rowPointers columns values
30+
(%BinSearch.searchCOO) valuesLength index rows columns values
3031

3132
match (%op) value with
3233
| Some resultValue ->
@@ -37,7 +38,6 @@ module Map =
3738
resultBitmap.[gid] <- 1
3839
| None -> resultBitmap.[gid] <- 0 @>
3940

40-
4141
let kernel =
4242
clContext.Compile <| preparePositions opAdd
4343

@@ -83,7 +83,6 @@ module Map =
8383

8484
resultBitmap, resultValues, resultRows, resultColumns
8585

86-
8786
let run<'a, 'b when 'a: struct and 'b: struct and 'b: equality>
8887
(clContext: ClContext)
8988
(opAdd: Expr<'a option -> 'b option>)

src/GraphBLAS-sharp.Backend/Matrix/COOMatrix/Map2.fs

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,36 +5,11 @@ open GraphBLAS.FSharp.Backend.Matrix
55
open Microsoft.FSharp.Quotations
66
open GraphBLAS.FSharp.Backend.Objects
77
open GraphBLAS.FSharp.Backend
8+
open GraphBLAS.FSharp.Backend.Quotes
89
open GraphBLAS.FSharp.Backend.Objects.ClMatrix
910
open GraphBLAS.FSharp.Backend.Objects.ClContext
1011

1112
module internal Map2 =
12-
let binSearch<'a> =
13-
<@ fun lenght sourceIndex (rowIndices: ClArray<int>) (columnIndices: ClArray<int>) (values: ClArray<'a>) ->
14-
15-
let mutable leftEdge = 0
16-
let mutable rightEdge = lenght - 1
17-
18-
let mutable result = None
19-
20-
while leftEdge <= rightEdge do
21-
let middleIdx = (leftEdge + rightEdge) / 2
22-
23-
let currentIndex: uint64 =
24-
((uint64 rowIndices.[middleIdx]) <<< 32)
25-
||| (uint64 columnIndices.[middleIdx])
26-
27-
if sourceIndex = currentIndex then
28-
result <- Some values.[middleIdx]
29-
30-
rightEdge <- -1 // TODO() break
31-
elif sourceIndex < currentIndex then
32-
rightEdge <- middleIdx - 1
33-
else
34-
leftEdge <- middleIdx + 1
35-
36-
result @>
37-
3813
let preparePositions<'a, 'b, 'c> (clContext: ClContext) workGroupSize opAdd =
3914

4015
let preparePositions (op: Expr<'a option -> 'b option -> 'c option>) =
@@ -51,10 +26,10 @@ module internal Map2 =
5126
(uint64 rowIndex <<< 32) ||| (uint64 columnIndex)
5227

5328
let leftValue =
54-
(%binSearch) leftValuesLength index leftRows leftColumns leftValues
29+
(%BinSearch.searchCOO) leftValuesLength index leftRows leftColumns leftValues
5530

5631
let rightValue =
57-
(%binSearch) rightValuesLength index rightRows rightColumn rightValues
32+
(%BinSearch.searchCOO) rightValuesLength index rightRows rightColumn rightValues
5833

5934
match (%op) leftValue rightValue with
6035
| Some value ->

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/Map.fs

Lines changed: 7 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,37 +3,14 @@
33
open Brahma.FSharp
44
open FSharp.Quotations
55
open GraphBLAS.FSharp.Backend
6+
open GraphBLAS.FSharp.Backend.Quotes
67
open GraphBLAS.FSharp.Backend.Matrix
78
open GraphBLAS.FSharp.Backend.Matrix.COO
89
open GraphBLAS.FSharp.Backend.Objects
910
open GraphBLAS.FSharp.Backend.Objects.ClMatrix
1011
open GraphBLAS.FSharp.Backend.Objects.ClContext
1112

12-
module Map =
13-
let binSearch<'a> =
14-
<@ fun startIndex nnzInRow sourceColumn (columnIndices: ClArray<int>) (values: ClArray<'a>) ->
15-
16-
let mutable leftEdge = startIndex
17-
let mutable rightEdge = startIndex + nnzInRow - 1
18-
19-
let mutable result = None
20-
21-
while leftEdge <= rightEdge do
22-
let middleIdx = (leftEdge + rightEdge) / 2
23-
24-
let currentColumn = columnIndices.[middleIdx]
25-
26-
if sourceColumn = currentColumn then
27-
result <- Some values.[middleIdx]
28-
29-
rightEdge <- -1 // TODO() break
30-
elif sourceColumn < currentColumn then
31-
rightEdge <- middleIdx - 1
32-
else
33-
leftEdge <- middleIdx + 1
34-
35-
result @>
36-
13+
module internal Map =
3714
let preparePositions<'a, 'b> (clContext: ClContext) workGroupSize opAdd =
3815

3916
let preparePositions (op: Expr<'a option -> 'b option>) =
@@ -46,12 +23,11 @@ module Map =
4623
let columnIndex = gid % columnCount
4724
let rowIndex = gid / columnCount
4825

49-
let nnzInRow =
50-
rowPointers.[rowIndex + 1]
51-
- rowPointers.[rowIndex]
26+
let startIndex = rowPointers.[rowIndex]
27+
let lastIndex = rowPointers.[rowIndex + 1] - 1
5228

5329
let value =
54-
(%binSearch) rowPointers.[rowIndex] nnzInRow columnIndex columns values
30+
(%BinSearch.searchInRange) startIndex lastIndex columnIndex columns values
5531

5632
match (%op) value with
5733
| Some resultValue ->
@@ -106,7 +82,6 @@ module Map =
10682

10783
resultBitmap, resultValues, resultRows, resultColumns
10884

109-
11085
let runToCOO<'a, 'b when 'a: struct and 'b: struct and 'b: equality>
11186
(clContext: ClContext)
11287
(opAdd: Expr<'a option -> 'b option>)
@@ -145,11 +120,11 @@ module Map =
145120
workGroupSize
146121
=
147122

148-
let elementwiseToCOO = runToCOO clContext opAdd workGroupSize
123+
let mapToCOO = runToCOO clContext opAdd workGroupSize
149124

150125
let toCSRInplace =
151126
Matrix.toCSRInplace clContext workGroupSize
152127

153128
fun (queue: MailboxProcessor<_>) allocationMode (matrix: ClMatrix.CSR<'a>) ->
154-
elementwiseToCOO queue allocationMode matrix
129+
mapToCOO queue allocationMode matrix
155130
|> toCSRInplace queue allocationMode

src/GraphBLAS-sharp.Backend/Quotes/Arithmetic.fs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,16 @@
33
open GraphBLAS.FSharp.Backend.Objects
44

55
module ArithmeticOperations =
6+
let inline mkOpWithConst zero op constant =
7+
<@ fun x ->
8+
let mutable res = zero
9+
10+
match x with
11+
| Some v -> res <- (op v constant)
12+
| None -> res <- constant
13+
14+
if res = zero then None else Some res @>
15+
616
let inline mkNumericSum zero =
717
<@ fun (x: 't option) (y: 't option) ->
818
let mutable res = zero
@@ -98,3 +108,9 @@ module ArithmeticOperations =
98108
let byteMulAtLeastOne = mkNumericMulAtLeastOne 0uy
99109
let floatMulAtLeastOne = mkNumericMulAtLeastOne 0.0
100110
let float32MulAtLeastOne = mkNumericMulAtLeastOne 0f
111+
112+
let notQ =
113+
<@ fun x ->
114+
match x with
115+
| Some true -> None
116+
| _ -> Some true @>
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
namespace GraphBLAS.FSharp.Backend.Quotes
2+
3+
open Brahma.FSharp
4+
5+
module BinSearch =
6+
let searchInRange<'a> =
7+
<@ fun leftEdge rightEdge sourceIndex (indices: ClArray<int>) (values: ClArray<'a>) ->
8+
9+
let mutable leftEdge = leftEdge
10+
let mutable rightEdge = rightEdge
11+
12+
let mutable result = None
13+
14+
while leftEdge <= rightEdge do
15+
let middleIdx = (leftEdge + rightEdge) / 2
16+
17+
let currentColumn = indices.[middleIdx]
18+
19+
if sourceIndex = currentColumn then
20+
result <- Some values.[middleIdx]
21+
22+
rightEdge <- -1 // TODO() break
23+
elif sourceIndex < currentColumn then
24+
rightEdge <- middleIdx - 1
25+
else
26+
leftEdge <- middleIdx + 1
27+
28+
result @>
29+
30+
let searchCOO<'a> =
31+
<@ fun lenght sourceIndex (rowIndices: ClArray<int>) (columnIndices: ClArray<int>) (values: ClArray<'a>) ->
32+
33+
let mutable leftEdge = 0
34+
let mutable rightEdge = lenght - 1
35+
36+
let mutable result = None
37+
38+
while leftEdge <= rightEdge do
39+
let middleIdx = (leftEdge + rightEdge) / 2
40+
41+
let currentIndex: uint64 =
42+
((uint64 rowIndices.[middleIdx]) <<< 32)
43+
||| (uint64 columnIndices.[middleIdx])
44+
45+
if sourceIndex = currentIndex then
46+
result <- Some values.[middleIdx]
47+
48+
rightEdge <- -1 // TODO() break
49+
elif sourceIndex < currentIndex then
50+
rightEdge <- middleIdx - 1
51+
else
52+
leftEdge <- middleIdx + 1
53+
54+
result @>
55+

tests/GraphBLAS-sharp.Tests/Matrix/Map.fs

Lines changed: 8 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ open Expecto.Logging
55
open Expecto.Logging.Message
66
open Microsoft.FSharp.Collections
77
open GraphBLAS.FSharp.Backend
8+
open GraphBLAS.FSharp.Backend.Quotes
89
open GraphBLAS.FSharp.Backend.Matrix
910
open GraphBLAS.FSharp.Backend.Objects
1011
open GraphBLAS.FSharp.Backend.Objects.ClContext
@@ -105,13 +106,7 @@ let testFixturesMapNot case =
105106
let q = case.TestContext.Queue
106107
q.Error.Add(fun e -> failwithf "%A" e)
107108

108-
let notQ =
109-
<@ fun x ->
110-
match x with
111-
| Some true -> None
112-
| _ -> Some true @>
113-
114-
createTestMap case false not (=) notQ Matrix.map ]
109+
createTestMap case false not (=) ArithmeticOperations.notQ Matrix.map ]
115110

116111
let notTests =
117112
operationGPUTests "Backend.Matrix.map not tests" testFixturesMapNot
@@ -121,35 +116,9 @@ let testFixturesMapAdd case =
121116
let q = case.TestContext.Queue
122117
q.Error.Add(fun e -> failwithf "%A" e)
123118

124-
let addFloat64Q =
125-
<@ fun x ->
126-
let mutable res = 0.0
127-
128-
match x with
129-
| Some v -> res <- (v + 10.0)
130-
| None -> res <- 10.0
131-
132-
if res = 0.0 then None else Some res @>
133-
134-
let addFloat32Q =
135-
<@ fun x ->
136-
let mutable res = 0.0f
137-
138-
match x with
139-
| Some v -> res <- (v + 10.0f)
140-
| None -> res <- 10.0f
141-
142-
if res = 0.0f then None else Some res @>
143-
144-
let addByte =
145-
<@ fun x ->
146-
let mutable res = 0uy
147-
148-
match x with
149-
| Some v -> res <- (v + 10uy)
150-
| None -> res <- 10uy
151-
152-
if res = 0uy then None else Some res @>
119+
let addFloat64Q = ArithmeticOperations.mkOpWithConst 0.0 (+) 10.0
120+
let addFloat32Q = ArithmeticOperations.mkOpWithConst 0.0f (+) 10.0f
121+
let addByte = ArithmeticOperations.mkOpWithConst 0uy (+) 10uy
153122

154123
if Utils.isFloat64Available context.ClDevice then
155124
createTestMap case 0.0 ((+) 10.0) Utils.floatIsEqual addFloat64Q Matrix.map
@@ -165,36 +134,9 @@ let testFixturesMapMul case =
165134
let q = case.TestContext.Queue
166135
q.Error.Add(fun e -> failwithf "%A" e)
167136

168-
let mulFloat64Q =
169-
<@ fun x ->
170-
let mutable res = 0.0
171-
172-
match x with
173-
| Some v -> res <- (v * 10.0)
174-
| _ -> ()
175-
176-
if res = 0.0 then None else Some res @>
177-
178-
let mulFloat32Q =
179-
<@ fun x ->
180-
let mutable res = 0.0f
181-
182-
match x with
183-
| Some v -> res <- (v * 10.0f)
184-
| _ -> ()
185-
186-
if res = 0.0f then None else Some res @>
187-
188-
189-
let mulByte =
190-
<@ fun x ->
191-
let mutable res = 0uy
192-
193-
match x with
194-
| Some v -> res <- (v * 10uy)
195-
| _ -> ()
196-
197-
if res = 0uy then None else Some res @>
137+
let mulFloat64Q = ArithmeticOperations.mkOpWithConst 0.0 (*) 10.0
138+
let mulFloat32Q = ArithmeticOperations.mkOpWithConst 0.0f (*) 10.0f
139+
let mulByte = ArithmeticOperations.mkOpWithConst 0uy (*) 10uy
198140

199141
if Utils.isFloat64Available context.ClDevice then
200142
createTestMap case 0.0 ((*) 10.0) Utils.floatIsEqual mulFloat64Q Matrix.map

0 commit comments

Comments
 (0)