Skip to content

Commit a5a3d8c

Browse files
committed
add: atLeastOneToNormalForm fun
1 parent 7b2e627 commit a5a3d8c

9 files changed

Lines changed: 243 additions & 244 deletions

File tree

src/GraphBLAS-sharp.Backend/Vector/DenseVector/DenseVector.fs

Lines changed: 6 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -38,44 +38,15 @@ module DenseVector =
3838

3939
resultVector
4040

41-
let elementWiseAtLeastOne<'a, 'b, 'c when 'a: struct and 'b: struct and 'c: struct>
42-
(clContext: ClContext)
43-
(opAdd: Expr<AtLeastOne<'a, 'b> -> 'c option>)
44-
(workGroupSize: int)
45-
=
46-
47-
let kernel = clContext.Compile(ElementwiseConstructor.atLeastOneKernel opAdd)
48-
49-
fun (processor: MailboxProcessor<_>) (leftVector: ClArray<'a option>) (rightVector: ClArray<'b option>) ->
50-
51-
let resultVector =
52-
clContext.CreateClArray(
53-
leftVector.Length,
54-
hostAccessMode = HostAccessMode.NotAccessible,
55-
deviceAccessMode = DeviceAccessMode.ReadWrite,
56-
allocationMode = AllocationMode.Default
57-
)
58-
59-
let ndRange =
60-
Range1D.CreateValid(leftVector.Length, workGroupSize)
61-
62-
let kernel = kernel.GetKernel()
63-
64-
processor.Post(
65-
Msg.MsgSetArguments
66-
(fun () -> kernel.KernelFunc ndRange leftVector.Length leftVector rightVector resultVector)
67-
)
68-
69-
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
70-
71-
resultVector
41+
let elementWiseAtLeastOne clContext op workGroupSize =
42+
elementWise clContext (ElementwiseConstructor.atLeastOneToNormalForm op) workGroupSize
7243

7344
let fillSubVector<'a, 'b when 'a: struct and 'b: struct>
7445
(clContext: ClContext)
7546
(maskOp: Expr<'a option -> 'b option -> 'a -> 'a option>)
7647
(workGroupSize: int) =
7748

78-
let kernel = clContext.Compile(ElementwiseConstructor.fillSubVector maskOp)
49+
let kernel = clContext.Compile(ElementwiseConstructor.fillSubVectorKernel maskOp)
7950

8051
fun (processor: MailboxProcessor<_>) (leftVector: ClArray<'a option>) (maskVector: ClArray<'b option>) (value: ClCell<'a>) ->
8152
let resultArray =
@@ -98,6 +69,9 @@ module DenseVector =
9869

9970
resultArray
10071

72+
let fillSubVectorAtLeasOne clContext opAdd workGroupSize =
73+
fillSubVector clContext (ElementwiseConstructor.fillSubVectorAtLeastOneToNormalForm opAdd) workGroupSize
74+
10175
let private getBitmap<'a when 'a: struct> (clContext: ClContext) (workGroupSize: int) =
10276

10377
let getPositions =

src/GraphBLAS-sharp.Backend/Vector/DenseVector/ElementwiseConstructor.fs

Lines changed: 87 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -4,59 +4,105 @@ open Brahma.FSharp
44
open GraphBLAS.FSharp.Backend.Common
55

66
module ElementwiseConstructor =
7-
let private elementWiseGeneralKernel writeOp =
8-
<@ fun (ndRange: Range1D) resultLength (leftVector: ClArray<'a option>) (rightVector: ClArray<'b option>) (resultVector: ClArray<'c option>) ->
7+
// let private elementWiseGeneralKernel writeOp =
8+
// <@ fun (ndRange: Range1D) resultLength (leftVector: ClArray<'a option>) (rightVector: ClArray<'b option>) (resultVector: ClArray<'c option>) ->
9+
//
10+
// let gid = ndRange.GlobalID0
11+
//
12+
// if gid < resultLength then
13+
// resultVector[gid] <- (%writeOp) leftVector.[gid] rightVector.[gid] @>
14+
//
15+
// let private elementWiseWrite opAdd =
16+
// <@
17+
// fun (leftItem: 'a option) (rightItem: 'b option) ->
18+
// (%opAdd) leftItem rightItem
19+
// @>
20+
//
21+
// let private elementWiseAtLeastOneWrite opAdd =
22+
// <@
23+
// fun (leftItem: 'a option) (rightItem: 'b option) ->
24+
// match leftItem, rightItem with
25+
// | Some left, Some right -> (%opAdd) (Both(left, right))
26+
// | Some left, None -> (%opAdd) (Left left)
27+
// | None, Some right -> (%opAdd) (Right right)
28+
// | _ -> None
29+
// @>
930

10-
let gid = ndRange.GlobalID0
31+
// let kernel opAdd = elementWiseGeneralKernel <| elementWiseWrite opAdd
32+
//
33+
// let atLeastOneKernel opAdd = elementWiseGeneralKernel <| elementWiseAtLeastOneWrite opAdd
1134

12-
if gid < resultLength then
13-
(%writeOp) gid leftVector rightVector resultVector @>
14-
15-
let private elementWiseWrite opAdd =
35+
let kernel opAdd =
1636
<@
17-
fun gid (leftVector: ClArray<'a option>) (rightVector: ClArray<'b option>) (resultArray: ClArray<'c option>) ->
18-
resultArray.[gid] <- (%opAdd) leftVector.[gid] rightVector.[gid]
19-
@>
37+
fun (ndRange: Range1D) resultLength (leftVector: ClArray<'a option>) (rightVector: ClArray<'b option>) (resultVector: ClArray<'c option>) ->
2038

21-
let private elementWiseAtLeastOneWrite opAdd =
22-
<@
23-
fun gid (leftVector: ClArray<'a option>) (rightVector: ClArray<'b option>) (resultArray: ClArray<'c option>) ->
24-
match leftVector.[gid], rightVector.[gid] with
25-
| Some left, Some right -> resultArray.[gid] <- (%opAdd) (Both(left, right))
26-
| Some left, None -> resultArray.[gid] <- (%opAdd) (Left left)
27-
| None, Some right -> resultArray.[gid] <- (%opAdd) (Right right)
28-
| _ -> resultArray.[gid] <- None
29-
@>
39+
let gid = ndRange.GlobalID0
3040

31-
let kernel opAdd = elementWiseGeneralKernel <| elementWiseWrite opAdd
41+
if gid < resultLength then
42+
resultVector[gid] <- (%opAdd) leftVector.[gid] rightVector.[gid]
43+
@>
3244

33-
let atLeastOneKernel opAdd = elementWiseGeneralKernel <| elementWiseAtLeastOneWrite opAdd
45+
// let private fillSubVectorGeneralKernel writeOp =
46+
// <@
47+
// fun (ndRange: Range1D) resultLength (leftVector: ClArray<'a option>) (rightVector: ClArray<'b option>) (value: ClCell<'a>) (resultVector: ClArray<'c option>) ->
48+
//
49+
// let gid = ndRange.GlobalID0
50+
//
51+
// if gid < resultLength then
52+
// resultVector[gid] <- (%writeOp) leftVector.[gid] rightVector.[gid] value.Value @>
53+
//
54+
// let private fillSubVectorWrite (opAdd: Expr<'a option -> 'b option -> 'a -> 'a option>) =
55+
// <@
56+
// fun (leftItem: 'a option) (rightItem: 'b option) (value: 'a) ->
57+
// (%opAdd) leftItem rightItem value
58+
// @>
59+
//
60+
// let private fillSubVectorAtLeastOneWrite (opAdd: Expr<AtLeastOne<'a, 'b> -> 'a-> 'a option>) =
61+
// <@
62+
// fun (leftItem: 'a option) (rightItem: 'b option) (values: 'a) ->
63+
// match leftItem, rightItem with
64+
// | Some left, Some right -> (%opAdd) (Both(left, right)) values
65+
// | Some left, None -> (%opAdd) (Left left) values
66+
// | None, Some right -> (%opAdd) (Right right) values
67+
// | _ -> None
68+
// @>
3469

35-
let private fillSubVectorGeneralKernel writeOp =
36-
<@
37-
fun (ndRange: Range1D) resultLength (leftVector: ClArray<'a option>) (rightVector: ClArray<'b option>) (value: ClCell<'a>) (resultVector: ClArray<'c option>) ->
70+
// let fillSubVector maskOp = fillSubVectorGeneralKernel <| fillSubVectorWrite maskOp
71+
//
72+
// let fillSubVectorAtLeastOne maskOp = fillSubVectorGeneralKernel <| fillSubVectorAtLeastOneWrite maskOp
73+
let fillSubVectorKernel opAdd =
74+
<@
75+
fun (ndRange: Range1D) resultLength (leftVector: ClArray<'a option>) (rightVector: ClArray<'b option>) (value: ClCell<'a>) (resultVector: ClArray<'c option>) ->
3876

39-
let gid = ndRange.GlobalID0
77+
let gid = ndRange.GlobalID0
4078

41-
if gid < resultLength then
42-
(%writeOp) gid leftVector rightVector value.Value resultVector @>
79+
if gid < resultLength then
80+
resultVector[gid] <- (%opAdd) leftVector.[gid] rightVector.[gid] value.Value @>
4381

44-
let private fillSubVectorWrite opAdd =
82+
let atLeastOneToNormalForm op =
4583
<@
46-
fun gid (leftVector: ClArray<'a option>) (rightVector: ClArray<'b option>) (value: 'a) (resultArray: ClArray<'c option>) ->
47-
resultArray.[gid] <- (%opAdd) leftVector.[gid] rightVector.[gid] value
84+
fun (leftItem: 'a option) (rightItem: 'b option) ->
85+
match leftItem, rightItem with
86+
| Some left, Some right ->
87+
(%op) (Both(left, right))
88+
| None, Some right ->
89+
(%op) (Right right)
90+
| Some left, None ->
91+
(%op) (Left left)
92+
| None, None ->
93+
None
4894
@>
4995

50-
let private fillSubVectorAtLeastOneWrite opAdd =
96+
let fillSubVectorAtLeastOneToNormalForm op =
5197
<@
52-
fun gid (leftVector: ClArray<'a option>) (rightVector: ClArray<'b option>) (values: 'a) (resultArray: ClArray<'c option>) ->
53-
match leftVector.[gid], rightVector.[gid] with
54-
| Some left, Some right -> resultArray.[gid] <- (%opAdd) (Both(left, right)) values
55-
| Some left, None -> resultArray.[gid] <- (%opAdd) (Left left) values
56-
| None, Some right -> resultArray.[gid] <- (%opAdd) (Right right) values
57-
| _ -> resultArray.[gid] <- None
98+
fun (leftItem: 'a option) (rightItem: 'b option) (value: 'a) ->
99+
match leftItem, rightItem with
100+
| Some left, Some right ->
101+
(%op) (Both(left, right)) value
102+
| None, Some right ->
103+
(%op) (Right right) value
104+
| Some left, None ->
105+
(%op) (Left left) value
106+
| None, None ->
107+
None
58108
@>
59-
60-
let fillSubVector maskOp = fillSubVectorGeneralKernel <| fillSubVectorWrite maskOp
61-
62-
let fillSubVectorAtLeastOne maskOp = fillSubVectorGeneralKernel <| fillSubVectorAtLeastOneWrite maskOp

src/GraphBLAS-sharp.Backend/Vector/SparseVector/ElementwiseConstructor.fs

Lines changed: 40 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -119,82 +119,6 @@ module ElementwiseConstructor =
119119
firstResultValues.[i] <- firstValuesBuffer.[beginIdx + boundaryX]
120120
isLeftBitMap.[i] <- 1 @>
121121

122-
module FillSubVectorRead =
123-
let both (opAdd: Expr<'a option -> 'b option -> 'a -> 'a option>) =
124-
<@
125-
fun gid (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (value: 'a) ->
126-
(%opAdd) (Some leftValues.[gid]) (Some rightValues.[gid + 1]) value
127-
@>
128-
129-
let left (opAdd: Expr<'a option -> 'b option -> 'a -> 'a option>) =
130-
<@
131-
fun gid (leftValues: ClArray<'a>) (value: 'a) ->
132-
(%opAdd) (Some leftValues.[gid]) None value
133-
@>
134-
135-
let right (opAdd: Expr<'a option -> 'b option -> 'a -> 'a option>) =
136-
<@
137-
fun gid (rightValues: ClArray<'b>) (value: 'a) ->
138-
(%opAdd) None (Some rightValues.[gid + 1]) value
139-
@>
140-
141-
module FillSubVectorAtLeasOneRead =
142-
let both (opAdd: Expr<AtLeastOne<'a,'b> -> 'a -> 'a option>) =
143-
<@
144-
fun gid (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (value: 'a) ->
145-
(%opAdd) (Both(leftValues.[gid], rightValues.[gid + 1])) value
146-
@>
147-
148-
let left (opAdd: Expr<AtLeastOne<'a,'b> -> 'a -> 'a option>) =
149-
<@
150-
fun gid (leftValues: ClArray<'a>) (value: 'a) ->
151-
(%opAdd) (Left(leftValues.[gid])) value
152-
@>
153-
154-
let right (opAdd: Expr<AtLeastOne<'a,'b> -> 'a -> 'a option>) =
155-
<@
156-
fun gid (rightValues: ClArray<'b>) (value: 'a) ->
157-
(%opAdd) (Right(rightValues.[gid])) value
158-
@>
159-
160-
module ElementWiseRead =
161-
let both (opAdd: Expr<'a option -> 'b option -> 'c option>) =
162-
<@
163-
fun gid (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) ->
164-
(%opAdd) (Some leftValues.[gid]) (Some rightValues.[gid + 1])
165-
@>
166-
167-
let left (opAdd: Expr<'a option -> 'b option -> 'c option>) =
168-
<@
169-
fun gid (leftValues: ClArray<'a>)->
170-
(%opAdd) (Some leftValues.[gid]) None
171-
@>
172-
173-
let right (opAdd: Expr<'a option -> 'b option -> 'c option>) =
174-
<@
175-
fun gid (rightValues: ClArray<'b>) ->
176-
(%opAdd) None (Some rightValues.[gid + 1])
177-
@>
178-
179-
module ElementWiseAtLeasOneRead =
180-
let both (opAdd: Expr<AtLeastOne<'a,'b> -> 'c option>) =
181-
<@
182-
fun gid (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) ->
183-
(%opAdd) (Both(leftValues.[gid], rightValues.[gid + 1]))
184-
@>
185-
186-
let left (opAdd: Expr<AtLeastOne<'a,'b> -> 'c option>) =
187-
<@
188-
fun gid (leftValues: ClArray<'a>) ->
189-
(%opAdd) (Left(leftValues.[gid]))
190-
@>
191-
192-
let right (opAdd: Expr<AtLeastOne<'a,'b> -> 'c option>) =
193-
<@
194-
fun gid (rightValues: ClArray<'b>) ->
195-
(%opAdd) (Right(rightValues.[gid]))
196-
@>
197-
198122
let private both<'c> =
199123
<@ fun index (result: 'c option) (rawPositionsBuffer: ClArray<int>) (allValuesBuffer: ClArray<'c>) ->
200124
rawPositionsBuffer.[index] <- 0
@@ -220,60 +144,70 @@ module ElementwiseConstructor =
220144
rawPositionsBuffer.[index] <- 1
221145
| None -> rawPositionsBuffer.[index] <- 0 @>
222146

223-
let private preparePositionsGeneral
224-
bothRead
225-
leftRead
226-
rightRead
227-
=
228-
229-
<@ fun (ndRange: Range1D) length (allIndices: ClArray<int>) (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (isLeft: ClArray<int>) (allValues: ClArray<'c>) (positions: ClArray<int>) ->
147+
let prepareFillVector opAdd =
148+
<@ fun (ndRange: Range1D) length (allIndices: ClArray<int>) (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (value: ClCell<'a>) (isLeft: ClArray<int>) (allValues: ClArray<'a>) (positions: ClArray<int>) ->
230149

231150
let gid = ndRange.GlobalID0
232151

152+
let value = value.Value
153+
233154
if gid < length - 1
234155
&& allIndices.[gid] = allIndices.[gid + 1] then
235-
let (result: 'c option) = (%bothRead) gid leftValues rightValues
156+
let result = (%opAdd) (Some leftValues[gid]) (Some rightValues[gid + 1]) value
236157

237158
(%both) gid result positions allValues
238159
elif (gid < length
239160
&& gid > 0
240161
&& allIndices.[gid - 1] <> allIndices.[gid])
241-
|| gid = 0 then
242-
243-
let leftResult = (%leftRead) gid leftValues
244-
let rightResult = (%rightRead) gid rightValues
162+
|| gid = 0 then
163+
let leftResult = (%opAdd) (Some leftValues.[gid]) None value
164+
let rightResult = (%opAdd) None (Some rightValues.[gid]) value
245165

246166
(%leftRight) gid leftResult rightResult isLeft allValues positions @>
247167

248-
let private prepareFillVectorGeneral bothRead leftRead rightRead =
249-
<@ fun (ndRange: Range1D) length (allIndices: ClArray<int>) (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (value: ClCell<'a>) (isLeft: ClArray<int>) (allValues: ClArray<'a>) (positions: ClArray<int>) ->
168+
let preparePositions opAdd =
169+
<@ fun (ndRange: Range1D) length (allIndices: ClArray<int>) (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (isLeft: ClArray<int>) (allValues: ClArray<'c>) (positions: ClArray<int>) ->
250170

251171
let gid = ndRange.GlobalID0
252172

253-
let value = value.Value
254-
255173
if gid < length - 1
256174
&& allIndices.[gid] = allIndices.[gid + 1] then
257-
let (result: 'a option) = (%bothRead) gid leftValues rightValues value
175+
let result = (%opAdd) (Some leftValues[gid]) (Some rightValues[gid + 1])
258176

259177
(%both) gid result positions allValues
260178
elif (gid < length
261179
&& gid > 0
262180
&& allIndices.[gid - 1] <> allIndices.[gid])
263181
|| gid = 0 then
264-
let leftResult = (%leftRead) gid leftValues value
265-
let rightResult = (%rightRead) gid rightValues value
182+
let leftResult = (%opAdd) (Some leftValues.[gid]) None
183+
let rightResult = (%opAdd) None (Some rightValues.[gid])
266184

267185
(%leftRight) gid leftResult rightResult isLeft allValues positions @>
268186

269-
let preparePositions opAdd =
270-
preparePositionsGeneral (ElementWiseRead.both opAdd) (ElementWiseRead.left opAdd) (ElementWiseRead.right opAdd)
271-
272-
let preparePositionsAtLeastOne opAdd =
273-
preparePositionsGeneral (ElementWiseAtLeasOneRead.both opAdd) (ElementWiseAtLeasOneRead.left opAdd) (ElementWiseAtLeasOneRead.right opAdd)
274-
275-
let prepareFillVector opAdd =
276-
prepareFillVectorGeneral (FillSubVectorRead.both opAdd) (FillSubVectorRead.left opAdd) (FillSubVectorRead.right opAdd)
277-
278-
let prepareFillVectorAtLeastOne opAdd =
279-
prepareFillVectorGeneral (FillSubVectorAtLeasOneRead.both opAdd) (FillSubVectorAtLeasOneRead.left opAdd) (FillSubVectorAtLeasOneRead.right opAdd)
187+
let atLeastOneToNormalForm (op: Expr<AtLeastOne<'a, 'b> -> 'c option>) =
188+
<@
189+
fun (leftItem: 'a option) (rightItem: 'b option) ->
190+
match leftItem, rightItem with
191+
| Some left, Some right ->
192+
(%op) (Both(left, right))
193+
| None, Some right ->
194+
(%op) (Right right)
195+
| Some left, None ->
196+
(%op) (Left left)
197+
| None, None ->
198+
None
199+
@>
200+
201+
let fillSubVectorAtLeastOneToNormalForm (op: Expr<AtLeastOne<'a, 'b> -> 'a -> 'a option>) =
202+
<@
203+
fun (leftItem: 'a option) (rightItem: 'b option) (value: 'a) ->
204+
match leftItem, rightItem with
205+
| Some left, Some right ->
206+
(%op) (Both(left, right)) value
207+
| None, Some right ->
208+
(%op) (Right right) value
209+
| Some left, None ->
210+
(%op) (Left left) value
211+
| None, None ->
212+
None
213+
@>

0 commit comments

Comments
 (0)