11namespace GraphBLAS.FSharp.Backend.Algorithms
22
3+ open GraphBLAS.FSharp
34open GraphBLAS.FSharp .Backend
45open Brahma.FSharp
56open GraphBLAS.FSharp .Objects
@@ -86,49 +87,52 @@ module internal PageRank =
8687 let copy =
8788 GraphBLAS.FSharp.ClArray.copy clContext workGroupSize
8889
89- let transpose =
90+ let transposeInPlace =
9091 Matrix.CSR.Matrix.transposeInPlace clContext workGroupSize
9192
9293 let multiply = clContext.Compile multiply
9394
94- fun ( queue : MailboxProcessor < Msg >) ( matrix : ClMatrix.CSR < float32 >) ->
95+ fun ( queue : MailboxProcessor < Msg >) ( matrix : ClMatrix < float32 >) ->
96+
97+ match matrix with
98+ | ClMatrix.CSR matrix ->
9599
96- let outDegree = countOutDegree queue matrix
100+ let outDegree = countOutDegree queue matrix
97101
98- let resultValues =
99- clContext.CreateClArrayWithSpecificAllocationMode( DeviceOnly, matrix.Values.Length)
102+ let resultValues =
103+ clContext.CreateClArrayWithSpecificAllocationMode( DeviceOnly, matrix.Values.Length)
100104
101- let kernel = multiply.GetKernel()
105+ let kernel = multiply.GetKernel()
102106
103- let ndRange =
104- Range1D.CreateValid( matrix.RowCount * workGroupSize, workGroupSize)
107+ let ndRange =
108+ Range1D.CreateValid( matrix.RowCount * workGroupSize, workGroupSize)
105109
106- queue.Post(
107- Msg.MsgSetArguments
108- ( fun () ->
109- kernel.KernelFunc
110- ndRange
111- matrix.RowCount
112- matrix.RowPointers
113- matrix.Values
114- outDegree
115- resultValues)
116- )
110+ queue.Post(
111+ Msg.MsgSetArguments
112+ ( fun () ->
113+ kernel.KernelFunc
114+ ndRange
115+ matrix.RowCount
116+ matrix.RowPointers
117+ matrix.Values
118+ outDegree
119+ resultValues)
120+ )
117121
118- queue.Post( Msg.CreateRunMsg<_, _> kernel)
122+ queue.Post( Msg.CreateRunMsg<_, _> kernel)
119123
120- outDegree.Free queue
124+ outDegree.Free queue
121125
122- let newMatrix =
123- { Context = clContext
124- RowCount = matrix.RowCount
125- ColumnCount = matrix.ColumnCount
126- RowPointers = copy queue DeviceOnly matrix.RowPointers
127- Columns = copy queue DeviceOnly matrix.Columns
128- Values = resultValues }
126+ let newMatrix =
127+ { Context = clContext
128+ RowCount = matrix.RowCount
129+ ColumnCount = matrix.ColumnCount
130+ RowPointers = copy queue DeviceOnly matrix.RowPointers
131+ Columns = copy queue DeviceOnly matrix.Columns
132+ Values = resultValues }
129133
130- let transposed = transpose queue DeviceOnly newMatrix
131- transposed
134+ transposeInPlace queue DeviceOnly newMatrix |> ClMatrix.CSR
135+ | _ -> failwith " Not implemented "
132136
133137 let run ( clContext : ClContext ) workGroupSize =
134138
@@ -140,34 +144,34 @@ module internal PageRank =
140144 let mul = ArithmeticOperations.float32MulOption
141145
142146 let spMVTo =
143- Operations.SpMV.runTo plus mul clContext workGroupSize
147+ Operations.SpMVInPlace plus mul clContext workGroupSize
144148
145149 let addToResult =
146- Vector.map2InPlace plus clContext workGroupSize
150+ GraphBLAS.FSharp. Vector.map2InPlace plus clContext workGroupSize
147151
148152 let subtractAndSquare =
149- Vector.map2InPlace minusAndSquare clContext workGroupSize
153+ GraphBLAS.FSharp. Vector.map2To minusAndSquare clContext workGroupSize
150154
151155 let reduce =
152- Vector.reduce <@ (+) @> clContext workGroupSize
156+ GraphBLAS.FSharp. Vector.reduce <@ (+) @> clContext workGroupSize
153157
154158 let create =
155- GraphBLAS.FSharp.ClArray .create clContext workGroupSize
159+ GraphBLAS.FSharp.Vector .create clContext workGroupSize
156160
157- fun ( queue : MailboxProcessor < Msg >) ( matrix : ClMatrix.CSR < float32 >) ->
161+ fun ( queue : MailboxProcessor < Msg >) ( matrix : ClMatrix < float32 >) ->
158162
159163 let vertexCount = matrix.RowCount
160164
161165 //None is 0
162- let mutable rank = create queue DeviceOnly vertexCount None
166+ let mutable rank = create queue DeviceOnly vertexCount Dense None
163167
164168 let mutable prevRank =
165- create queue DeviceOnly vertexCount ( Some( 1.0 f / ( float32 vertexCount)))
169+ create queue DeviceOnly vertexCount Dense ( Some( 1.0 f / ( float32 vertexCount)))
166170
167- let mutable errors = create queue DeviceOnly vertexCount None
171+ let mutable errors = create queue DeviceOnly vertexCount Dense None
168172
169173 let addition =
170- create queue DeviceOnly vertexCount ( Some(( 1.0 f - alpha) / ( float32 vertexCount)))
174+ create queue DeviceOnly vertexCount Dense ( Some(( 1.0 f - alpha) / ( float32 vertexCount)))
171175
172176 let mutable error = accuracy + 0.1 f
173177
@@ -178,7 +182,7 @@ module internal PageRank =
178182
179183 // rank = matrix*rank + (1 - alpha)/N
180184 spMVTo queue matrix prevRank rank
181- addToResult queue rank addition rank
185+ addToResult queue rank addition
182186
183187 // error
184188 subtractAndSquare queue rank prevRank errors
@@ -189,8 +193,8 @@ module internal PageRank =
189193 rank <- prevRank
190194 prevRank <- temp
191195
192- prevRank.Free queue
193- errors.Free queue
194- addition.Free queue
196+ prevRank.Dispose queue
197+ errors.Dispose queue
198+ addition.Dispose queue
195199
196200 rank
0 commit comments