diff --git a/mcs/class/System/System.IO.Compression/DeflateStream.cs b/mcs/class/System/System.IO.Compression/DeflateStream.cs index 22aaa6eae180..f4fc095d8dcc 100644 --- a/mcs/class/System/System.IO.Compression/DeflateStream.cs +++ b/mcs/class/System/System.IO.Compression/DeflateStream.cs @@ -34,6 +34,7 @@ using System.IO; using System.Runtime.InteropServices; using System.Runtime.Remoting.Messaging; +using Microsoft.Win32.SafeHandles; #if MONOTOUCH using MonoTouch; @@ -316,8 +317,7 @@ class DeflateStreamNative UnmanagedReadOrWrite feeder; // This will be passed to unmanaged code and used there Stream base_stream; - IntPtr z_stream; - GCHandle data; + SafeNativeZStreamHandle z_stream; bool disposed; byte [] io_buffer; @@ -328,10 +328,9 @@ private DeflateStreamNative () public static DeflateStreamNative Create (Stream compressedStream, CompressionMode mode, bool gzip) { var dsn = new DeflateStreamNative (); - dsn.data = GCHandle.Alloc (dsn); dsn.feeder = mode == CompressionMode.Compress ? new UnmanagedReadOrWrite (UnmanagedWrite) : new UnmanagedReadOrWrite (UnmanagedRead); - dsn.z_stream = CreateZStream (mode, gzip, dsn.feeder, GCHandle.ToIntPtr (dsn.data)); - if (dsn.z_stream == IntPtr.Zero) { + dsn.z_stream = SafeNativeZStreamHandle.Create (mode, gzip, dsn, dsn.feeder); + if (dsn.z_stream == null) { dsn.Dispose (true); return null; } @@ -353,14 +352,10 @@ public void Dispose (bool disposing) io_buffer = null; - IntPtr zz = z_stream; - z_stream = IntPtr.Zero; - if (zz != IntPtr.Zero) - CloseZStream (zz); // This will Flush() the remaining output if any - } - - if (data.IsAllocated) { - data.Free (); + SafeNativeZStreamHandle zz = z_stream; + z_stream = null; + if (zz != null) + zz.Dispose(); } } @@ -476,6 +471,51 @@ static void CheckResult (int result, string where) throw new IOException (error + " " + where); } + class SafeNativeZStreamHandle : SafeHandleZeroOrMinusOneIsInvalid + { + public static SafeNativeZStreamHandle Create (CompressionMode mode, bool gzip, DeflateStreamNative deflateStreamNative, UnmanagedReadOrWrite feeder) + { + GCHandle deflateStreamNativeHandle = GCHandle.Alloc (deflateStreamNative); + SafeNativeZStreamHandle nativeSafeHandle = CreateZStream (mode, gzip, feeder, GCHandle.ToIntPtr (deflateStreamNativeHandle)); + if (nativeSafeHandle == null || nativeSafeHandle.IsInvalid) { + if (deflateStreamNativeHandle.IsAllocated) + deflateStreamNativeHandle.Free(); + + return null; + } + + nativeSafeHandle.deflateStreamNativeHandle = deflateStreamNativeHandle; + + return nativeSafeHandle; + } + + private SafeNativeZStreamHandle () : + base (true) + { + } + + protected override bool ReleaseHandle() + { + if (handle != IntPtr.Zero) { + CloseZStream (handle); // This will Flush() the remaining output if any + } + + if (deflateStreamNativeHandle.IsAllocated) { + deflateStreamNativeHandle.Free (); + } + + return true; + } + + [DllImport (LIBNAME, CallingConvention=CallingConvention.Cdecl)] + static extern SafeNativeZStreamHandle CreateZStream (CompressionMode compress, bool gzip, UnmanagedReadOrWrite feeder, IntPtr data); + + [DllImport (LIBNAME, CallingConvention=CallingConvention.Cdecl)] + static extern int CloseZStream (IntPtr stream); + + GCHandle deflateStreamNativeHandle; + } + #if MONOTOUCH || MONODROID const string LIBNAME = "__Internal"; #else @@ -483,19 +523,13 @@ static void CheckResult (int result, string where) #endif [DllImport (LIBNAME, CallingConvention=CallingConvention.Cdecl)] - static extern IntPtr CreateZStream (CompressionMode compress, bool gzip, UnmanagedReadOrWrite feeder, IntPtr data); - - [DllImport (LIBNAME, CallingConvention=CallingConvention.Cdecl)] - static extern int CloseZStream (IntPtr stream); - - [DllImport (LIBNAME, CallingConvention=CallingConvention.Cdecl)] - static extern int Flush (IntPtr stream); + static extern int Flush (SafeNativeZStreamHandle stream); [DllImport (LIBNAME, CallingConvention=CallingConvention.Cdecl)] - static extern int ReadZStream (IntPtr stream, IntPtr buffer, int length); + static extern int ReadZStream (SafeNativeZStreamHandle stream, IntPtr buffer, int length); [DllImport (LIBNAME, CallingConvention=CallingConvention.Cdecl)] - static extern int WriteZStream (IntPtr stream, IntPtr buffer, int length); + static extern int WriteZStream (SafeNativeZStreamHandle stream, IntPtr buffer, int length); } }