Skip to content

Commit 1f2b95b

Browse files
committed
Implement some vector constructors; add mask complementation
1 parent ecb4b79 commit 1f2b95b

7 files changed

Lines changed: 88 additions & 12 deletions

File tree

benchmarks/GraphBLAS-sharp.Benchmarks/Program.fs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ let main argv =
66
let benchmarks = BenchmarkSwitcher [|
77
// typeof<EWiseAddBenchmarks4Float32>
88
// typeof<EWiseAddBenchmarks4Bool>
9-
// typeof<BFSBenchmarks>
10-
typeof<QGBenchmarks>
9+
typeof<BFSBenchmarks>
10+
// typeof<QGBenchmarks>
1111
|]
1212

1313
benchmarks.Run argv |> ignore

src/GraphBLAS-sharp/Backend/CSRMatrix/SpMSpV.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ module internal rec SpMSpV =
177177
}
178178

179179
elif mask.Indices.Length = 0 && not mask.IsComplemented ||
180-
mask.Indices.Length = mask.Size && mask.IsComplemented then
180+
mask.Indices.Length = mask.Size && mask.IsComplemented then
181181
return {
182182
Size = matrix.RowCount
183183
Indices = [||]
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
namespace GraphBLAS.FSharp.Backend.Mask
2+
3+
open Brahma.FSharp.OpenCL.WorkflowBuilder.Basic
4+
open Brahma.FSharp.OpenCL.WorkflowBuilder.Evaluation
5+
open GraphBLAS.FSharp
6+
open GraphBLAS.FSharp.Backend.Common
7+
open Brahma.OpenCL
8+
9+
module internal GetComplemented =
10+
let mask1D (mask: Mask1D) = opencl {
11+
let size = mask.Size
12+
let nnz = mask.Indices.Length
13+
14+
let bitmap = Array.create size 1
15+
let getComplementedBitmap =
16+
<@
17+
fun (range: _1D)
18+
(maskIndices: int[])
19+
(bitmap: int[]) ->
20+
21+
let gid = range.GlobalID0
22+
23+
if gid < nnz then
24+
let maskIdx = maskIndices.[gid]
25+
bitmap.[maskIdx] <- 0
26+
@>
27+
28+
do! RunCommand getComplementedBitmap <| fun kernelPrepare ->
29+
kernelPrepare
30+
<| _1D(Utils.getDefaultGlobalSize nnz, Utils.defaultWorkGroupSize)
31+
<| mask.Indices
32+
<| bitmap
33+
34+
let! (positions, _) = PrefixSum.runExclude bitmap
35+
36+
let complementedIndices = Array.zeroCreate<int> (size - nnz)
37+
let setPosotions =
38+
<@
39+
fun (range: _1D)
40+
(positions: int[])
41+
(bitmap: int[])
42+
(complementedIndices: int[]) ->
43+
44+
let gid = range.GlobalID0
45+
46+
if gid < size && bitmap.[gid] = 1 then
47+
complementedIndices.[positions.[gid]] <- gid
48+
@>
49+
50+
do! RunCommand setPosotions <| fun kernelPrepare ->
51+
kernelPrepare
52+
<| _1D(Utils.getDefaultGlobalSize size, Utils.defaultWorkGroupSize)
53+
<| positions
54+
<| bitmap
55+
<| complementedIndices
56+
57+
return Mask1D(complementedIndices, size, not mask.IsComplemented)
58+
}

src/GraphBLAS-sharp/GraphBLAS-sharp.fsproj

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
<Compile Include="Backend/COOVector/EWiseAdd.fs" />
4545
<Compile Include="Backend/COOVector/FillSubVector.fs" />
4646
<Compile Include="Backend/COOVector/AssignSubVector.fs" />
47+
<Compile Include="Backend/Mask/GetComplemented.fs" />
4748
<Compile Include="Operations/Matrix.fs" />
4849
<Compile Include="Operations/Vector.fs" />
4950
<Compile Include="Operations/Scalar.fs" />
@@ -63,4 +64,4 @@
6364
</Content>
6465
</ItemGroup>
6566
<Import Project="..\..\.paket\Paket.Restore.targets" />
66-
</Project>
67+
</Project>

src/GraphBLAS-sharp/Operations/Scalar.fs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ module Scalar =
99
constructors
1010
*)
1111

12-
let create (value: 'a) : GraphblasEvaluation<Scalar<'a>> = graphblas { return ScalarWrapped { Value = [| value |] } }
12+
let create (value: 'a) : GraphblasEvaluation<Scalar<'a>> =
13+
graphblas { return ScalarWrapped { Value = [| value |] } }
1314

1415
(*
1516
methods

src/GraphBLAS-sharp/Operations/Vector.fs

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,13 @@ module Vector =
1818
failwith "Not Implemented yet"
1919

2020
let ofList (size: int) (elements: (int * 'a) list) : GraphblasEvaluation<Vector<'a>> =
21-
failwith "Not Implemented yet"
21+
let (indices, values) =
22+
elements
23+
|> Array.ofList
24+
|> Array.sortBy fst
25+
|> Array.unzip
26+
27+
graphblas { return VectorCOO <| COOVector.FromTuples(size, indices, values) }
2228

2329
// можно оставить, но с условием, что будет создаваться full vector
2430
// let ofArray (array: 'a[]) : GraphblasEvaluation<Vector<'a>> =
@@ -31,7 +37,7 @@ module Vector =
3137
failwith "Not Implemented yet"
3238

3339
let zeroCreate<'a when 'a : struct> (size: int) : GraphblasEvaluation<Vector<'a>> =
34-
failwith "Not Implemented yet"
40+
graphblas { return VectorCOO <| COOVector.FromTuples(size, [||], [||]) }
3541

3642
(*
3743
methods
@@ -77,8 +83,9 @@ module Vector =
7783
match vector with
7884
| VectorCOO vector ->
7985
opencl {
80-
let! resultIndices = Copy.copyArray vector.Indices
81-
return Mask1D(resultIndices, vector.Size, true)
86+
let! indices = Copy.copyArray vector.Indices
87+
let! complementedMask = Mask.GetComplemented.mask1D <| Mask1D(indices, vector.Size, true)
88+
return complementedMask
8289
}
8390
|> EvalGB.fromCl
8491

@@ -106,9 +113,19 @@ module Vector =
106113
let extractValue (vector: Vector<'a>) (idx: int) : GraphblasEvaluation<Scalar<'a>> =
107114
failwith "Not Implemented yet"
108115

116+
// assignToVector
109117
/// t <- vec
110118
let assignVector (target: Vector<'a>) (source: Vector<'a>) : GraphblasEvaluation<unit> =
111-
failwith "Not Implemented yet"
119+
if target.Size <> source.Size then
120+
invalidArg "source" <| sprintf "The size of source vector must be %A. Received: %A" target.Size source.Size
121+
122+
match source, target with
123+
| VectorCOO source, VectorCOO target ->
124+
opencl {
125+
target.Indices <- source.Indices
126+
target.Values <- source.Values
127+
}
128+
|> EvalGB.fromCl
112129

113130
/// t.[mask] <- vec
114131
let assignSubVector (target: Vector<'a>) (mask: Mask1D) (source: Vector<'a>) : GraphblasEvaluation<unit> =

tests/GraphBLAS-sharp.Tests/MatrixOperationsTests/MxvTests.fs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,7 @@ let tests =
228228

229229
deviceType = DeviceType.Cpu &&
230230
case.MatrixCase = CSR &&
231-
case.VectorCase = VectorFormat.COO &&
232-
case.MaskCase <> Complemented
231+
case.VectorCase = VectorFormat.COO
233232
)
234233
|> List.collect testFixtures
235234
|> testList "Matrix.mxv tests"

0 commit comments

Comments
 (0)