Skip to content

Commit ab7d3c3

Browse files
committed
refactor: flags in bfs
1 parent 1450728 commit ab7d3c3

5 files changed

Lines changed: 18 additions & 14 deletions

File tree

src/GraphBLAS-sharp.Backend/Algorithms/BFS.fs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ open GraphBLAS.FSharp.Backend.Common
88
open GraphBLAS.FSharp.Backend.Quotes
99
open GraphBLAS.FSharp.Backend.Vector
1010
open GraphBLAS.FSharp.Backend.Vector.Dense
11+
open GraphBLAS.FSharp.Backend.Objects.ClContext
1112
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
1213

1314
module BFS =
@@ -24,7 +25,7 @@ module BFS =
2425
let zeroCreate =
2526
ClArray.zeroCreate clContext workGroupSize
2627

27-
let ofList = Vector.ofList clContext Dense
28+
let ofList = Vector.ofList clContext workGroupSize
2829

2930
let maskComplementedTo =
3031
DenseVector.elementWiseTo clContext Mask.complementedMaskOp workGroupSize
@@ -40,7 +41,8 @@ module BFS =
4041

4142
let levels = zeroCreate queue vertexCount
4243

43-
let frontier = ofList vertexCount [ source, 1 ]
44+
let frontier =
45+
ofList queue Dense GPUOnly vertexCount [ source, 1 ]
4446

4547
match frontier with
4648
| ClVector.Dense front ->

src/GraphBLAS-sharp.Backend/Objects/ClContextExtensions.fs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ module ClContext =
88
| CPUInterop
99

1010
type ClContext with
11-
member this.CreateClArrayWithFlag(mode, (size: int)) =
11+
member this.CreateClArrayWithFlag(mode, size: int) =
1212
match mode with
1313
| GPUOnly ->
1414
this.CreateClArray(
@@ -25,7 +25,7 @@ module ClContext =
2525
allocationMode = AllocationMode.Default
2626
)
2727

28-
member this.CreateClArrayWithFlag(mode, (array: 'a [])) =
28+
member this.CreateClArrayWithFlag(mode, array: 'a []) =
2929
match mode with
3030
| GPUOnly ->
3131
this.CreateClArray(

src/GraphBLAS-sharp.Backend/Vector/Vector.fs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ module Vector =
1717
let zeroCreate =
1818
ClArray.zeroCreate clContext workGroupSize
1919

20-
fun (processor: MailboxProcessor<_>) (size: int) (format: VectorFormat) (flag: AllocationFlag) ->
20+
fun (processor: MailboxProcessor<_>) size format flag ->
2121
match format with
2222
| Sparse ->
2323
ClVector.Sparse
@@ -35,7 +35,7 @@ module Vector =
3535
let map =
3636
ClArray.map clContext workGroupSize <@ Some @>
3737

38-
fun (processor: MailboxProcessor<_>) (format: VectorFormat) flag size (elements: (int * 'a) list) ->
38+
fun (processor: MailboxProcessor<_>) format flag size (elements: (int * 'a) list) ->
3939
let indices, values =
4040
elements
4141
|> Array.ofList

tests/GraphBLAS-sharp.Tests/Vector/OfList.fs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ let checkResult
3232

3333
let correctnessGenericTest<'a when 'a: struct>
3434
(isEqual: 'a -> 'a -> bool)
35-
(ofList: VectorFormat -> int -> (int * 'a) list -> ClVector<'a>)
35+
(ofList: MailboxProcessor<_> -> VectorFormat -> ClContext.AllocationFlag -> int -> (int * 'a) list -> ClVector<'a>)
3636
(toCoo: MailboxProcessor<_> -> ClVector<'a> -> ClVector<'a>)
3737
(case: OperationCase<VectorFormat>)
3838
(elements: (int * 'a) [])
@@ -54,7 +54,8 @@ let correctnessGenericTest<'a when 'a: struct>
5454

5555
let actualSize = (Array.max indices) + abs sizeDelta + 1
5656

57-
let clActual = ofList case.Format actualSize elements
57+
let clActual =
58+
ofList q case.Format ClContext.CPUInterop actualSize elements
5859

5960
let clCooActual = toCoo q clActual
6061

@@ -78,15 +79,15 @@ let testFixtures (case: OperationCase<VectorFormat>) =
7879
let getCorrectnessTestName datatype =
7980
sprintf "Correctness on %s, %A" datatype case.Format
8081

81-
let boolOfList = Vector.ofList context
82+
let boolOfList = Vector.ofList context wgSize
8283

8384
let toCoo = Vector.toSparse context wgSize
8485

8586
case
8687
|> correctnessGenericTest<bool> (=) boolOfList toCoo
8788
|> testPropertyWithConfig config (getCorrectnessTestName "bool")
8889

89-
let intOfList = Vector.ofList context
90+
let intOfList = Vector.ofList context wgSize
9091

9192
let toCoo = Vector.toSparse context wgSize
9293

@@ -95,15 +96,15 @@ let testFixtures (case: OperationCase<VectorFormat>) =
9596
|> testPropertyWithConfig config (getCorrectnessTestName "int")
9697

9798

98-
let byteOfList = Vector.ofList context
99+
let byteOfList = Vector.ofList context wgSize
99100

100101
let toCoo = Vector.toSparse context wgSize
101102

102103
case
103104
|> correctnessGenericTest<byte> (=) byteOfList toCoo
104105
|> testPropertyWithConfig config (getCorrectnessTestName "byte")
105106

106-
let floatOfList = Vector.ofList context
107+
let floatOfList = Vector.ofList context wgSize
107108

108109
let toCoo = Vector.toSparse context wgSize
109110

tests/GraphBLAS-sharp.Tests/Vector/ZeroCreate.fs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,16 @@ let checkResult size (actual: Vector<'a>) =
2727
Expect.equal vector.Indices [| 0 |] "The index array must contain the 0"
2828

2929
let correctnessGenericTest<'a when 'a: struct and 'a: equality>
30-
(zeroCreate: MailboxProcessor<_> -> int -> VectorFormat -> ClVector<'a>)
30+
(zeroCreate: MailboxProcessor<_> -> int -> VectorFormat -> ClContext.AllocationFlag -> ClVector<'a>)
3131
(case: OperationCase<VectorFormat>)
3232
(vectorSize: int)
3333
=
3434

3535
if vectorSize > 0 then
3636
let q = case.TestContext.Queue
3737

38-
let (clVector: ClVector<'a>) = zeroCreate q vectorSize case.Format
38+
let (clVector: ClVector<'a>) =
39+
zeroCreate q vectorSize case.Format ClContext.CPUInterop
3940

4041
let hostVector = clVector.ToHost q
4142

0 commit comments

Comments
 (0)