Skip to content

Commit f76e8ac

Browse files
committed
add: DenseVector.elementwise
1 parent 04a09cc commit f76e8ac

5 files changed

Lines changed: 99 additions & 83 deletions

File tree

src/GraphBLAS-sharp.Backend/Common/StandardOperations.fs

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -102,15 +102,20 @@ module StandardOperations =
102102
let floatMulAtLeastOne = mkNumericMulAtLeastOne 0.0
103103
let float32MulAtLeastOne = mkNumericMulAtLeastOne 0f
104104

105-
let maskAtLeastOne<'a, 'b when 'a: struct and 'b: struct> res =
106-
<@ fun (value: AtLeastOne<'a, 'b>) ->
107-
match value with
105+
let mask<'a, 'b when 'a: struct and 'b: struct> =
106+
<@ fun (left: 'a option) (right: 'b option) value ->
107+
match left, right with
108+
| _, None -> left
109+
| _ -> Some value @>
110+
111+
let maskAtLeastOne<'a, 'b when 'a: struct and 'b: struct> =
112+
<@ fun (pair: AtLeastOne<'a, 'b>) value ->
113+
match pair with
108114
| Left left -> Some left
109-
| _ -> Some res @>
115+
| _ -> Some value @>
110116

111-
let complementedMask<'a, 'b when 'a: struct and 'b: struct> res =
112-
<@ fun (left: 'a option) (right: 'b option) ->
117+
let complementedMask<'a, 'b when 'a: struct and 'b: struct> =
118+
<@ fun (left: 'a option) (right: 'b option) value ->
113119
match left, right with
114-
| Some left, Some _-> Some left
115-
| None, Some _ -> None
116-
| _ -> Some res @>
120+
| _, Some _-> left
121+
| _ -> Some value @>

src/GraphBLAS-sharp.Backend/GraphBLAS-sharp.Backend.fsproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
<Compile Include="Matrix/CSRMatrix/SpMV.fs" />
3535
<Compile Include="Matrix/Matrix.fs" />
3636
<Compile Include="Vector/SparseVector/SparseVector.fs" />
37+
<Compile Include="Vector\DenseVector\ElementwiseQuotes.fs" />
3738
<Compile Include="Vector/DenseVector/DenseVector.fs" />
3839
<Compile Include="Vector/Vector.fs" />
3940
<!--Compile Include="Backend/CSRMatrix/GetTuples.fs" /-->

src/GraphBLAS-sharp.Backend/Vector/DenseVector/DenseVector.fs

Lines changed: 19 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,7 @@ module DenseVector =
1212
(workGroupSize: int)
1313
=
1414

15-
let eWiseAdd =
16-
<@ fun (ndRange: Range1D) resultLength (leftVector: ClArray<'a option>) (rightVector: ClArray<'b option>) (resultVector: ClArray<'c option>) ->
17-
18-
let gid = ndRange.GlobalID0
19-
20-
if gid < resultLength then
21-
resultVector.[gid] <- (%opAdd) leftVector.[gid] rightVector.[gid] @>
22-
23-
let kernel = clContext.Compile(eWiseAdd)
15+
let kernel = clContext.Compile(ElementwiseQuotes.kernel opAdd)
2416

2517
fun (processor: MailboxProcessor<_>) (leftVector: ClArray<'a option>) (rightVector: ClArray<'b option>) ->
2618

@@ -52,19 +44,7 @@ module DenseVector =
5244
(workGroupSize: int)
5345
=
5446

55-
let eWiseAdd =
56-
<@ fun (ndRange: Range1D) resultLength (leftVector: ClArray<'a option>) (rightVector: ClArray<'b option>) (resultVector: ClArray<'c option>) ->
57-
58-
let gid = ndRange.GlobalID0
59-
60-
if gid < resultLength then
61-
match leftVector.[gid], rightVector.[gid] with
62-
| Some left, Some right -> resultVector.[gid] <- (%opAdd) (Both(left, right))
63-
| Some left, None -> resultVector.[gid] <- (%opAdd) (Left left)
64-
| None, Some right -> resultVector.[gid] <- (%opAdd) (Right right)
65-
| _ -> resultVector.[gid] <- None @>
66-
67-
let kernel = clContext.Compile(eWiseAdd)
47+
let kernel = clContext.Compile(ElementwiseQuotes.atLeastOneKernel opAdd)
6848

6949
fun (processor: MailboxProcessor<_>) (leftVector: ClArray<'a option>) (rightVector: ClArray<'b option>) ->
7050

@@ -90,66 +70,35 @@ module DenseVector =
9070

9171
resultVector
9272

93-
let fillSubVector<'a, 'b when 'a: struct and 'b: struct> (clContext: ClContext) (workGroupSize: int) (scalar: 'a) =
94-
95-
let eWiseAdd =
96-
elementWiseAtLeastOne clContext (StandardOperations.maskAtLeastOne scalar) workGroupSize
97-
98-
fun (processor: MailboxProcessor<_>) (leftVector: ClArray<'a option>) (maskVector: ClArray<'b option>) ->
99-
100-
let clScalar = clContext.CreateClCell scalar
101-
102-
let resultVector = eWiseAdd processor leftVector maskVector
103-
104-
processor.Post(Msg.CreateFreeMsg<_>(maskVector))
105-
106-
processor.Post(Msg.CreateFreeMsg<_>(clScalar))
107-
108-
resultVector
109-
110-
let complemented<'a when 'a: struct> (clContext: ClContext) (workGroupSize: int) =
111-
112-
let complemented =
113-
<@ fun (ndRange: Range1D) length (inputArray: ClArray<'a option>) (defaultValue: ClCell<'a>) (resultArray: ClArray<'a option>) ->
114-
115-
let gid = ndRange.GlobalID0
116-
117-
if gid < length then
118-
match inputArray.[gid] with
119-
| None -> resultArray.[gid] <- Some defaultValue.Value
120-
| _ -> () @>
121-
122-
123-
let kernel = clContext.Compile(complemented)
124-
125-
let create =
126-
ClArray.zeroCreate clContext workGroupSize
127-
128-
fun (processor: MailboxProcessor<_>) (vector: ClArray<'a option>) ->
129-
130-
let length = vector.Length
73+
let fillSubVector<'a, 'b when 'a: struct and 'b: struct>
74+
(clContext: ClContext)
75+
(maskOp: Expr<'a option -> 'b option -> 'a -> 'a option>)
76+
(workGroupSize: int) =
13177

132-
let resultArray = create processor length
78+
let kernel = clContext.Compile(ElementwiseQuotes.fillSubVector maskOp)
13379

134-
let defaultValue =
135-
clContext.CreateClCell Unchecked.defaultof<'a>
80+
fun (processor: MailboxProcessor<_>) (leftVector: ClArray<'a option>) (maskVector: ClArray<'b option>) (value: ClCell<'a>) ->
81+
let resultArray =
82+
clContext.CreateClArray(
83+
leftVector.Length,
84+
hostAccessMode = HostAccessMode.NotAccessible,
85+
deviceAccessMode = DeviceAccessMode.ReadWrite,
86+
allocationMode = AllocationMode.Default
87+
)
13688

13789
let ndRange =
138-
Range1D.CreateValid(length, workGroupSize)
90+
Range1D.CreateValid(leftVector.Length, workGroupSize)
13991

14092
let kernel = kernel.GetKernel()
14193

14294
processor.Post(
143-
Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange length vector defaultValue resultArray)
95+
Msg.MsgSetArguments(fun () ->
96+
kernel.KernelFunc ndRange leftVector.Length leftVector maskVector value resultArray)
14497
)
14598

146-
processor.Post(Msg.CreateRunMsg(kernel))
147-
148-
processor.Post(Msg.CreateFreeMsg(defaultValue))
149-
15099
resultArray
151100

152-
let getBitmap<'a when 'a: struct> (clContext: ClContext) (workGroupSize: int) =
101+
let private getBitmap<'a when 'a: struct> (clContext: ClContext) (workGroupSize: int) =
153102

154103
let getPositions =
155104
<@ fun (ndRange: Range1D) length (vector: ClArray<'a option>) (positions: ClArray<int>) ->
@@ -202,7 +151,6 @@ module DenseVector =
202151
resultIndices.[index] <- gid
203152
| None -> () @>
204153

205-
206154
let kernel = clContext.Compile(getValuesAndIndices)
207155

208156
let getPositions = getBitmap clContext workGroupSize
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
namespace GraphBLAS.FSharp.Backend
2+
3+
open Brahma.FSharp
4+
open GraphBLAS.FSharp.Backend.Common
5+
6+
module ElementwiseQuotes =
7+
let private elementWiseGeneralKernel writeOp =
8+
<@ fun (ndRange: Range1D) resultLength (leftVector: ClArray<'a option>) (rightVector: ClArray<'b option>) (resultVector: ClArray<'c option>) ->
9+
10+
let gid = ndRange.GlobalID0
11+
12+
if gid < resultLength then
13+
(%writeOp) gid leftVector rightVector resultVector @>
14+
15+
let private elementWiseWrite opAdd =
16+
<@
17+
fun gid (leftVector: ClArray<'a option>) (rightVector: ClArray<'b option>) (resultArray: ClArray<'c option>) ->
18+
resultArray.[gid] <- (%opAdd) leftVector.[gid] rightVector.[gid]
19+
@>
20+
21+
let private elementWiseAtLeastOneWrite opAdd =
22+
<@
23+
fun gid (leftVector: ClArray<'a option>) (rightVector: ClArray<'b option>) (resultArray: ClArray<'c option>) ->
24+
match leftVector.[gid], rightVector.[gid] with
25+
| Some left, Some right -> resultArray.[gid] <- (%opAdd) (Both(left, right))
26+
| Some left, None -> resultArray.[gid] <- (%opAdd) (Left left)
27+
| None, Some right -> resultArray.[gid] <- (%opAdd) (Right right)
28+
| _ -> resultArray.[gid] <- None
29+
@>
30+
31+
let kernel opAdd = elementWiseGeneralKernel <| elementWiseWrite opAdd
32+
33+
let atLeastOneKernel opAdd = elementWiseGeneralKernel <| elementWiseAtLeastOneWrite opAdd
34+
35+
let private fillSubVectorGeneralKernel writeOp =
36+
<@
37+
fun (ndRange: Range1D) resultLength (leftVector: ClArray<'a option>) (rightVector: ClArray<'b option>) (value: ClCell<'a>) (resultVector: ClArray<'c option>) ->
38+
39+
let gid = ndRange.GlobalID0
40+
41+
if gid < resultLength then
42+
(%writeOp) gid leftVector rightVector value.Value resultVector @>
43+
44+
let private fillSubVectorWrite opAdd =
45+
<@
46+
fun gid (leftVector: ClArray<'a option>) (rightVector: ClArray<'b option>) (value: 'a) (resultArray: ClArray<'c option>) ->
47+
resultArray.[gid] <- (%opAdd) leftVector.[gid] rightVector.[gid] value
48+
@>
49+
50+
let private fillSubVectorAtLeastOneWrite opAdd =
51+
<@
52+
fun gid (leftVector: ClArray<'a option>) (rightVector: ClArray<'b option>) (values: 'a) (resultArray: ClArray<'c option>) ->
53+
match leftVector.[gid], rightVector.[gid] with
54+
| Some left, Some right -> resultArray.[gid] <- (%opAdd) (Both(left, right)) values
55+
| Some left, None -> resultArray.[gid] <- (%opAdd) (Left left) values
56+
| None, Some right -> resultArray.[gid] <- (%opAdd) (Right right) values
57+
| _ -> resultArray.[gid] <- None
58+
@>
59+
60+
let fillSubVector maskOp = fillSubVectorGeneralKernel <| fillSubVectorWrite maskOp
61+
62+
let fillSubVectorAtLeastOne maskOp = fillSubVectorGeneralKernel <| fillSubVectorAtLeastOneWrite maskOp

src/GraphBLAS-sharp.Backend/Vector/Vector.fs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,15 +118,15 @@ module Vector =
118118
SparseVector.fillSubVector clContext workGroupSize
119119

120120
let denseFillVector =
121-
DenseVector.fillSubVector clContext workGroupSize
121+
DenseVector.fillSubVector clContext StandardOperations.mask workGroupSize
122122

123123
let toCooVector =
124124
DenseVector.toSparse clContext workGroupSize
125125

126126
let toCooMask =
127127
DenseVector.toSparse clContext workGroupSize
128128

129-
fun (processor: MailboxProcessor<_>) (vector: ClVector<'a>) (maskVector: ClVector<'b>) (value: 'a) ->
129+
fun (processor: MailboxProcessor<_>) (vector: ClVector<'a>) (maskVector: ClVector<'b>) (value: ClCell<'a>) ->
130130
match vector, maskVector with
131131
| ClVectorSparse vector, ClVectorSparse mask ->
132132
ClVectorSparse
@@ -143,7 +143,7 @@ module Vector =
143143
<| cooFillVector value processor vector mask
144144
| ClVectorDense vector, ClVectorDense mask ->
145145
ClVectorDense
146-
<| denseFillVector value processor vector mask
146+
<| denseFillVector processor vector mask value
147147

148148
let reduce (clContext: ClContext) (workGroupSize: int) (opAdd: Expr<'a -> 'a -> 'a>) =
149149
let cooReduce =

0 commit comments

Comments
 (0)