起因:
最近在写一个导数据的程序,需要从几个老数据表中取出n多的数据,然后加以处理再添加到新数据库的对应表中。单步操作太慢了,这不正是多线程的用武之地吗?
对于每一种数据我都得写一套类似的代码,表意代码如下
1 | // Fetch one batch of legacy rows from the old database |
2 | DataSet dsUser = OldDbAccess.GetOldUsers(minId); |
3 | // Split the rows of the DataSet into n parts and keep them in the sectionUsers DataTables |
4 | DataTable[] sectionUsers = new DataTable[threadCn]; |
5 | // Declare threadCn AutoResetEvents; each one is signalled when its thread has finished |
6 | AutoResetEvent[] evts = new AutoResetEvent[threadCn]; |
7 | // Initialise the entries of evts |
8 | // Bundle each data part with its AutoResetEvent and hand it to the ThreadPool; the actual processing is omitted here |
由于老数据有n种,每一种的处理方法又不一样,所以我写了n个类似上面的处理步骤,这太累了吧,于是想重构,将上面的操作步骤中相同的地方提取出来。于是有了AsyncHelper静态类,这个静态类有几个公开的静态方法来做上面那些分多个线程处理大量数据的步骤中一样的过程。
第一个方法的签名应该是这样子的
1 | /// <summary> |
2 | /// 执行多线程操作任务 |
3 | /// </summary> |
4 | /// <param name="dataCollection">多线程操作的数据集合</param> |
5 | /// <param name="threadCn">分多少个线程来做</param> |
6 | /// <param name="processItemMethod">处理数据集合中单个数据使用的处理方法</param> |
7 | public static void DoAsync(IList dataCollection, int threadCn, WaitCallback processItemMethod) |
大量的数据用一个IList类型的对象作为参数传递,要分多少个线程来处理这些数据用threadCn指定。需要注意的是WaitHandle.WaitAll方法一次最多只能等待64个句柄,所以threadCn不能超过这个上限(下面的代码要求threadCn小于64)。最后一个参数是处理大量数据中的一个数据的方法。方法的签名出来了,我在上面又做了几次重复的处理,写这个方法应该不是个问题。同样的我们也是需要把IList数据分成threadCn份,然后将每一份都交给ThreadPool来处理,这很简单。
/// <summary>
/// Splits <paramref name="dataCollection"/> round-robin into at most
/// <paramref name="threadCn"/> partitions, processes every partition on a
/// thread-pool thread and blocks until all partitions are done.
/// </summary>
/// <param name="dataCollection">The items to process.</param>
/// <param name="threadCn">Number of partitions/workers; must be at least 2 and less than 64
/// because <see cref="WaitHandle.WaitAll(WaitHandle[])"/> accepts a limited number of handles.
/// It is reduced to the item count when there are fewer items than workers.</param>
/// <param name="processItemMethod">Callback invoked once for every item.</param>
/// <exception cref="ArgumentNullException">
/// <paramref name="dataCollection"/> or <paramref name="processItemMethod"/> is null.</exception>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="threadCn"/> is out of range.</exception>
public static void DoAsync(IList dataCollection, int threadCn, WaitCallback processItemMethod)
{
    if (dataCollection == null) throw new ArgumentNullException("dataCollection");

    // Fail on the calling thread instead of with a NullReferenceException on a pool thread.
    if (processItemMethod == null) throw new ArgumentNullException("processItemMethod");

    if (threadCn >= 64 || threadCn < 2)
    {
        throw new ArgumentOutOfRangeException("threadCn", "threadCn 参数必须在2和64之间");
    }

    int itemCount = dataCollection.Count;

    // Nothing to process: WaitHandle.WaitAll rejects an empty handle array, so return early.
    if (itemCount == 0) return;

    if (threadCn > itemCount) threadCn = itemCount;

    ArrayList[] partitions = new ArrayList[threadCn];
    AutoResetEvent[] evts = new AutoResetEvent[threadCn];

    for (int i = 0; i < threadCn; i++)
    {
        partitions[i] = new ArrayList();
        evts[i] = new AutoResetEvent(false);
    }

    // Round-robin distribution keeps the partitions within one item of each other in size.
    for (int i = 0; i < itemCount; i++)
    {
        partitions[i % threadCn].Add(dataCollection[i]);
    }

    for (int i = 0; i < threadCn; i++)
    {
        // Argument layout expected by DoPrivate: { items, callback, completion event }.
        ThreadPool.QueueUserWorkItem(DoPrivate, new object[] { partitions[i], processItemMethod, evts[i] });
    }

    WaitHandle.WaitAll(evts);

    // Every worker has signalled, so the OS event handles can be released now
    // instead of waiting for finalization.
    for (int i = 0; i < threadCn; i++)
    {
        evts[i].Close();
    }
}
/// <summary>
/// Thread-pool entry point: runs the callback over one partition and then
/// signals the partition's completion event.
/// </summary>
/// <param name="data">An <c>object[]</c> holding, in order: the partition (<see cref="IList"/>),
/// the per-item callback (<see cref="WaitCallback"/>) and the completion event
/// (<see cref="AutoResetEvent"/>).</param>
private static void DoPrivate(object data)
{
    // Direct casts: a wrongly shaped argument fails here with a clear InvalidCastException
    // rather than later with a NullReferenceException.
    object[] args = (object[])data;
    IList items = (IList)args[0];
    WaitCallback method = (WaitCallback)args[1];
    AutoResetEvent evt = (AutoResetEvent)args[2];

    try
    {
        foreach (object item in items)
        {
            method(item);
        }
    }
    finally
    {
        // Always signal, even when the callback throws, so the waiting thread cannot hang.
        evt.Set();
    }
}
这个很容易实现,不过既然要做封装我们就不得不多考虑一些,我们的线程好比是一个一个的侦察兵,这次给他们的任务是抓一个敌人回来问问敌情,任务要求只抓一个敌人,也就是说如果某一个侦察兵抓到一个敌人之后要给其他战友发信息,告诉他们别忙了,任务已经完成了。这个该怎么办呢,办法总是要比问题多的。
WaitHandle类有WaitAny静态方法,上面侦察兵的例子不就是个WaitAny吗?主线程需要在接收到一个线程完成的信号后通知所有线程:“任务完成了,大家都回家吧”。大家如果有兴趣的话,可以给出自己的方案,我的方案明天放出来。明天一并要解决的还有取得多个执行操作的返回值问题。
我们需要解决WaitAny和取得异步执行的返回值的问题。地球人都知道Thread和ThreadPool接受的委托都是没有返回值的。要想取得返回值,我们就得自己动手了,我们需要构造一个AsyncContext类,由这个类来保存异步执行的状态并存储返回值。
代码如下:
1 | using System; |
2 | using System.Collections.Generic; |
3 | using System.Text; |
4 | using System.Collections; |
5 | using System.Threading; |
6 | using System.Diagnostics; |
7 | |
namespace AppUtility
{
    /// <summary>
    /// Callback that processes a single item and returns a result for it.
    /// </summary>
    /// <param name="state">The item taken from the data collection.</param>
    /// <returns>The value stored in the result table under the item.</returns>
    public delegate object DoGetObjTask(object state);

    /// <summary>
    /// Splits a collection round-robin into partitions and processes each partition on a
    /// thread-pool thread, optionally collecting one result per item.
    /// </summary>
    /// <remarks>
    /// All overloads throw <see cref="ArgumentNullException"/> when the collection or the callback
    /// is null and <see cref="ArgumentOutOfRangeException"/> when <c>threadCn</c> is not at least 2
    /// and less than 64. An empty collection is a no-op.
    /// </remarks>
    public static class AsyncHelper
    {
        /// <summary>
        /// Processes every item and blocks until all workers have finished.
        /// </summary>
        /// <param name="dataCollection">The items to process.</param>
        /// <param name="threadCn">Number of workers (at least 2, less than 64); reduced to the item count if larger.</param>
        /// <param name="processItemMethod">Callback invoked once per item.</param>
        public static void DoAsync(IList dataCollection, int threadCn, WaitCallback processItemMethod)
        {
            DoAsync(dataCollection, threadCn, processItemMethod, true);
        }

        /// <summary>
        /// Processes the items with a result-producing callback and collects the results.
        /// </summary>
        /// <param name="dataCollection">The items to process.</param>
        /// <param name="threadCn">Number of workers (at least 2, less than 64); reduced to the item count if larger.</param>
        /// <param name="processItemMethod">Callback invoked per item; its return value is stored in <paramref name="processResult"/>.</param>
        /// <param name="needWaitAll">True to block until all workers have finished. False to return as soon
        /// as the first worker has finished its partition; the other workers are then asked to stop before
        /// starting another item.</param>
        /// <param name="processResult">Synchronized table mapping each processed item to its result.
        /// When <paramref name="needWaitAll"/> is false, workers that are still inside a callback may add
        /// entries after this method has returned, so lock <see cref="Hashtable.SyncRoot"/> while enumerating.</param>
        public static void DoAsync(IList dataCollection, int threadCn, DoGetObjTask processItemMethod, bool needWaitAll, out Hashtable processResult)
        {
            DoAsyncPrivate(dataCollection, threadCn, null, processItemMethod, needWaitAll, true, out processResult);
        }

        /// <summary>
        /// Processes every item with a result-producing callback, blocks until all workers
        /// have finished and returns the collected results.
        /// </summary>
        /// <param name="dataCollection">The items to process.</param>
        /// <param name="threadCn">Number of workers (at least 2, less than 64); reduced to the item count if larger.</param>
        /// <param name="processItemMethod">Callback invoked per item; its return value is stored in <paramref name="processResult"/>.</param>
        /// <param name="processResult">Synchronized table mapping each item to its result.</param>
        public static void DoAsync(IList dataCollection, int threadCn, DoGetObjTask processItemMethod, out Hashtable processResult)
        {
            DoAsyncPrivate(dataCollection, threadCn, null, processItemMethod, true, true, out processResult);
        }

        /// <summary>
        /// Processes the items with a callback that returns nothing.
        /// </summary>
        /// <param name="dataCollection">The items to process.</param>
        /// <param name="threadCn">Number of workers (at least 2, less than 64); reduced to the item count if larger.</param>
        /// <param name="processItemMethod">Callback invoked once per item.</param>
        /// <param name="needWaitAll">True to block until all workers have finished. False to return as soon
        /// as the first worker has finished its partition; the other workers are then asked to stop before
        /// starting another item.</param>
        public static void DoAsync(IList dataCollection, int threadCn, WaitCallback processItemMethod, bool needWaitAll)
        {
            Hashtable ignored;
            DoAsyncPrivate(dataCollection, threadCn, processItemMethod, null, needWaitAll, false, out ignored);
        }

        /// <summary>
        /// Shared implementation behind all <c>DoAsync</c> overloads: validates the arguments,
        /// partitions the data, queues one worker per partition and waits for all or any of them.
        /// </summary>
        private static void DoAsyncPrivate(IList dataCollection, int threadCn, WaitCallback processItemMethod, DoGetObjTask getObjMethod, bool needWaitAll, bool hasReturnValue, out Hashtable processResult)
        {
            if (dataCollection == null) throw new ArgumentNullException("dataCollection");

            if (threadCn >= 64 || threadCn < 2)
            {
                throw new ArgumentOutOfRangeException("threadCn", "threadCn 参数必须在2和64之间");
            }

            // Fail on the calling thread rather than with a NullReferenceException on a pool thread.
            if (hasReturnValue)
            {
                if (getObjMethod == null) throw new ArgumentNullException("processItemMethod");
            }
            else
            {
                if (processItemMethod == null) throw new ArgumentNullException("processItemMethod");
            }

            int itemCount = dataCollection.Count;
            if (threadCn > itemCount) threadCn = itemCount;

            DataWithStateList dataWithStates = new DataWithStateList(itemCount);
            AsyncContext context = AsyncContext.GetContext(threadCn, dataWithStates, needWaitAll, hasReturnValue, processItemMethod, getObjMethod);
            processResult = context.ProcessResult;

            // Nothing to process: WaitAll/WaitAny reject an empty handle array, so return early.
            if (threadCn == 0) return;

            List<DataWithState>[] partitions = new List<DataWithState>[threadCn];
            AutoResetEvent[] evts = new AutoResetEvent[threadCn];

            for (int i = 0; i < threadCn; i++)
            {
                partitions[i] = new List<DataWithState>();
                evts[i] = new AutoResetEvent(false);
            }

            // Round-robin distribution. Partitions share the tracked DataWithState instances, so a
            // worker claims an item by reference instead of searching for it by equality.
            for (int i = 0; i < itemCount; i++)
            {
                DataWithState item = new DataWithState(dataCollection[i], ProcessState.WaitForProcess);
                partitions[i % threadCn].Add(item);
                dataWithStates.Add(item);
            }

            for (int i = 0; i < threadCn; i++)
            {
                // Argument layout expected by DoPrivate: { partition, context, completion event }.
                ThreadPool.QueueUserWorkItem(DoPrivate, new object[] { partitions[i], context, evts[i] });
            }

            if (needWaitAll)
            {
                WaitHandle.WaitAll(evts);

                // Every worker has signalled, so the OS handles can be released deterministically.
                for (int i = 0; i < threadCn; i++)
                {
                    evts[i].Close();
                }
            }
            else
            {
                WaitHandle.WaitAny(evts);
                context.SetBreakSignal();
                // The remaining workers still signal their events when they stop, so the handles
                // cannot be closed here; they are released by finalization.
            }
        }

        /// <summary>
        /// State shared by all workers of one DoAsync call: the callbacks, the item tracker,
        /// the result table and the cooperative stop flag.
        /// </summary>
        private sealed class AsyncContext
        {
            /// <summary>Creates a context; the result table exists only when results are collected.</summary>
            public static AsyncContext GetContext(
                int threadCn,
                DataWithStateList dataWithStates,
                bool needWaitAll,
                bool hasReturnValue,
                WaitCallback processItemMethod,
                DoGetObjTask hasReturnValueMethod)
            {
                AsyncContext context = new AsyncContext();
                context.ThreadCount = threadCn;
                context.DataWithStates = dataWithStates;
                context.NeedWaitAll = needWaitAll;
                context.HasReturnValue = hasReturnValue;
                if (hasReturnValue)
                {
                    context.ProcessResult = Hashtable.Synchronized(new Hashtable());
                    context.HasReturnValueMethod = hasReturnValueMethod;
                }
                else
                {
                    context.VoidMethod = processItemMethod;
                }
                return context;
            }

            internal int ThreadCount;

            internal DataWithStateList DataWithStates;

            internal bool NeedWaitAll;

            internal bool HasReturnValue;

            internal WaitCallback VoidMethod;

            internal DoGetObjTask HasReturnValueMethod;

            // Written by the waiting thread and read by the workers; volatile makes the write visible.
            private volatile bool _breakSignal;

            private Hashtable _processResult;

            /// <summary>Synchronized item-to-result table, or null when no results are collected.</summary>
            internal Hashtable ProcessResult
            {
                get { return _processResult; }
                set { _processResult = value; }
            }

            /// <summary>Stores the result of one item; the synchronized wrapper serializes writers.</summary>
            internal void SetReturnValue(object obj, object result)
            {
                _processResult[obj] = result;
            }

            /// <summary>Asks the workers to stop before starting another item (wait-any mode only).</summary>
            internal void SetBreakSignal()
            {
                if (NeedWaitAll) throw new NotSupportedException("设定为NeedWaitAll时不可设置BreakSignal");

                _breakSignal = true;
            }

            /// <summary>True when the workers should stop; always false in wait-all mode.</summary>
            internal bool NeedBreak
            {
                get { return !NeedWaitAll && _breakSignal; }
            }

            /// <summary>Runs the configured callback on an already claimed item and marks it processed.</summary>
            internal void Exec(DataWithState item)
            {
                object data = item.Data;
                if (HasReturnValue)
                {
                    SetReturnValue(data, HasReturnValueMethod(data));
                }
                else
                {
                    VoidMethod(data);
                }
                DataWithStates.SetProcessed(item);
            }
        }

        /// <summary>Life cycle of a tracked item; an item only ever moves forward through these states.</summary>
        private enum ProcessState : byte
        {
            WaitForProcess = 0,
            Processing = 1,
            Processed = 2
        }

        /// <summary>
        /// Thread-safe tracker for all items of one call. A worker must claim an item before
        /// executing it, which guarantees that every item is processed at most once even when
        /// idle workers take over items from other partitions.
        /// </summary>
        private sealed class DataWithStateList
        {
            private readonly object _gate = new object();
            private readonly List<DataWithState> _items;
            private int _waitingCount;

            // Every item before this index has already left the WaitForProcess state.
            private int _scanStart;

            public DataWithStateList(int capacity)
            {
                _items = new List<DataWithState>(capacity);
            }

            /// <summary>Registers an item; called while the work is being partitioned.</summary>
            public void Add(DataWithState item)
            {
                lock (_gate)
                {
                    _items.Add(item);
                    if (item.State == ProcessState.WaitForProcess) _waitingCount++;
                }
            }

            /// <summary>Number of items that no worker has claimed yet.</summary>
            public int WaitForDataCount
            {
                get
                {
                    lock (_gate)
                    {
                        return _waitingCount;
                    }
                }
            }

            /// <summary>
            /// Claims <paramref name="item"/> for the calling worker.
            /// </summary>
            /// <returns>True if the item was still waiting and now belongs to the caller;
            /// false if another worker has claimed it already.</returns>
            public bool TryClaim(DataWithState item)
            {
                lock (_gate)
                {
                    if (item.State != ProcessState.WaitForProcess) return false;

                    item.State = ProcessState.Processing;
                    _waitingCount--;
                    return true;
                }
            }

            /// <summary>
            /// Claims the first item that is still waiting, regardless of its partition.
            /// </summary>
            /// <returns>The claimed item, or null when nothing is waiting.</returns>
            public DataWithState ClaimNextWaiting()
            {
                lock (_gate)
                {
                    while (_scanStart < _items.Count)
                    {
                        DataWithState candidate = _items[_scanStart++];
                        if (candidate.State == ProcessState.WaitForProcess)
                        {
                            candidate.State = ProcessState.Processing;
                            _waitingCount--;
                            return candidate;
                        }
                    }
                    return null;
                }
            }

            /// <summary>Marks a claimed item as completely processed.</summary>
            public void SetProcessed(DataWithState item)
            {
                lock (_gate)
                {
                    item.State = ProcessState.Processed;
                }
            }
        }

        /// <summary>An item of the data collection together with its processing state.</summary>
        private sealed class DataWithState
        {
            public readonly object Data;
            public ProcessState State;

            public DataWithState(object data, ProcessState state)
            {
                Data = data;
                State = state;
            }
        }

#if DEBUG
        // Running number used only to label workers in the debug trace.
        private static int _threadNo = 0;
#endif

        /// <summary>
        /// Thread-pool entry point of one worker: processes the worker's own partition and, in
        /// wait-all mode, afterwards helps with unclaimed items of other partitions.
        /// </summary>
        /// <param name="state">An <c>object[]</c> holding, in order: the partition
        /// (<c>List&lt;DataWithState&gt;</c>), the shared <c>AsyncContext</c> and the completion
        /// <see cref="AutoResetEvent"/>.</param>
        private static void DoPrivate(object state)
        {
            object[] args = (object[])state;

            List<DataWithState> ownItems = (List<DataWithState>)args[0];
            AsyncContext context = (AsyncContext)args[1];
            AutoResetEvent evt = (AutoResetEvent)args[2];

            DataWithStateList tracker = context.DataWithStates;

#if DEBUG
            // Pool threads are reused and Thread.Name may only be assigned once, so the label is
            // kept in a local instead of being written to the thread.
            int workerNo = Interlocked.Increment(ref _threadNo) - 1;
            string threadName = "Thread " + workerNo + "[" + Thread.CurrentThread.ManagedThreadId + "]";
            Trace.WriteLine("线程ID:" + threadName);
#endif
            try
            {
                foreach (DataWithState item in ownItems)
                {
                    if (context.NeedBreak)
                    {
#if DEBUG
                        Trace.WriteLine("线程" + threadName + "未执行完跳出");
#endif
                        break;
                    }

                    // Skip items that another worker has already taken over.
                    if (!tracker.TryClaim(item)) continue;

                    context.Exec(item);

#if DEBUG
                    Trace.WriteLine(string.Format("线程{0}处理{1}", threadName, item.Data));
#endif
                }

                if (context.NeedWaitAll)
                {
                    // Own partition done: while more than ThreadCount items are still unclaimed,
                    // take work from the other partitions so the call finishes sooner.
                    while (tracker.WaitForDataCount > context.ThreadCount)
                    {
                        DataWithState stolen = tracker.ClaimNextWaiting();
                        if (stolen == null) break;

                        context.Exec(stolen);

#if DEBUG
                        Trace.WriteLine(string.Format("线程{0}执行另一个进程的数据{1}", threadName, stolen.Data));
#endif
                    }
                }
            }
            finally
            {
                // Always signal, even when a callback throws, so the waiting thread cannot hang.
                evt.Set();
            }
        }
    }
}
如何使用AsyncHelper类,请看下面的测试代码:
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Text;
using System.Threading;
using AppUtility;

namespace ConsoleApplication2
{
    /// <summary>
    /// Demo for AsyncHelper: writes 200 files through the helper and then
    /// sequentially, printing the elapsed time of each run.
    /// </summary>
    class Program
    {
        static void Main(string[] args)
        {
            Stopwatch sw = new Stopwatch();
            sw.Start();

            /* Alternative demo using the overload without return values:
            List<string> testFiles = new List<string>();
            for (int i = 0; i < 100; i++)
            {
                testFiles.Add("D:\\test\\async\\file_" + i.ToString() + ".log");
            }
            AsyncHelper.DoAsync(testFiles, 10, WriteFile);
            Console.WriteLine("异步写耗时"+sw.ElapsedMilliseconds + "ms");
            */

            List<string> testFiles = new List<string>();
            for (int i = 0; i < 200; i++)
            {
                testFiles.Add("D:\\test\\async\\file_" + i.ToString() + ".log");
            }

            // needWaitAll is false: the call returns once the first worker has finished.
            Hashtable result;
            AsyncHelper.DoAsync(testFiles, 20, WriteFileAndReturnRowCount, false, out result);
            Console.WriteLine("异步写耗时" + sw.ElapsedMilliseconds + "ms");

            if (result != null)
            {
                // Workers that are still finishing an item may add entries concurrently; holding
                // SyncRoot during the enumeration keeps them out (replaces the former short sleep).
                lock (result.SyncRoot)
                {
                    foreach (object key in result.Keys)
                    {
                        Console.WriteLine("{0}={1}", key, result[key]);
                    }
                }
            }

            sw.Reset();
            sw.Start();
            for (int i = 0; i < 200; i++)
            {
                WriteFile("D:\\test\\sync\\file_" + i.ToString() + ".log");
            }
            Console.WriteLine("同步写耗时" + sw.ElapsedMilliseconds + "ms");
            Console.Read();
        }

        /// <summary>
        /// Writes 10000 lines, each a new GUID, to the given file, creating the directory if needed.
        /// </summary>
        /// <param name="objFilePath">The target file path as a string.</param>
        static void WriteFile(object objFilePath)
        {
            string filePath = (string)objFilePath;

            // CreateDirectory does nothing when the directory already exists.
            Directory.CreateDirectory(Path.GetDirectoryName(filePath));

            int rowCn = 10000;
            using (StreamWriter writer = new StreamWriter(filePath, false, Encoding.Default))
            {
                for (int i = 0; i < rowCn; i++)
                {
                    writer.WriteLine(Guid.NewGuid());
                }
            }
        }

        /// <summary>
        /// Writes the file like <see cref="WriteFile"/> and returns a value for the result table.
        /// </summary>
        /// <param name="objFilePath">The target file path as a string.</param>
        /// <returns>The local time at which writing finished, as a long time string
        /// (despite the method name, not a row count).</returns>
        static object WriteFileAndReturnRowCount(object objFilePath)
        {
            WriteFile(objFilePath);
            return DateTime.Now.ToLongTimeString();
        }
    }
}