/* Copyright 2010-present MongoDB Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ using System; using System.IO; using System.Text; namespace MongoDB.Bson.IO { /// /// Represents a Stream backed by an IByteBuffer. Similar to MemoryStream but backed by an IByteBuffer /// instead of a byte array and also implements the BsonStream interface for higher performance BSON I/O. /// public class ByteBufferStream : BsonStream, IStreamEfficientCopyTo { // private fields private IByteBuffer _buffer; private bool _disposed; private int _length; private readonly bool _ownsBuffer; private int _position; private readonly byte[] _temp = new byte[12]; // constructors /// /// Initializes a new instance of the class. /// /// The buffer. /// Whether the stream owns the buffer and should Dispose it when done. public ByteBufferStream(IByteBuffer buffer, bool ownsBuffer = false) { if (buffer == null) { throw new ArgumentNullException("buffer"); } _buffer = buffer; _ownsBuffer = ownsBuffer; _length = buffer.Length; } // public properties /// /// Gets the buffer. /// /// /// The buffer. /// public IByteBuffer Buffer { get { ThrowIfDisposed(); return _buffer; } } /// public override bool CanRead { get { return !_disposed; } } /// public override bool CanSeek { get { return !_disposed; } } /// public override bool CanTimeout { get { return false; } } /// public override bool CanWrite { get { return !_disposed && !_buffer.IsReadOnly; } } /// public override long Length { get { ThrowIfDisposed(); return _length; } } /// public override long Position { get { ThrowIfDisposed(); return _position; } set { if (value < 0 || value > int.MaxValue) { throw new ArgumentOutOfRangeException("value"); } ThrowIfDisposed(); _position = (int)value; } } // public methods /// public void EfficientCopyTo(Stream destination) { long remainingCount; while ((remainingCount = Length - Position) > 0) { var segment = _buffer.AccessBackingBytes((int)Position); var count = (int)Math.Min(segment.Count, remainingCount); destination.Write(segment.Array, segment.Offset, count); Position += count; } } /// public override void Flush() { ThrowIfDisposed(); // do nothing } /// public override int Read(byte[] buffer, int offset, int count) { if (buffer == null) { throw new ArgumentNullException("buffer"); } if (offset < 0 || offset > buffer.Length) { throw new ArgumentOutOfRangeException("offset"); } if (count < 0 || offset + count > buffer.Length) { throw new ArgumentOutOfRangeException("count"); } ThrowIfDisposed(); if (_position >= _length) { return 0; } var available = _length - _position; if (count > available) { count = available; } _buffer.GetBytes(_position, buffer, offset, count); _position += count; return count; } /// public override int ReadByte() { ThrowIfDisposed(); if (_position >= _length) { return -1; } return _buffer.GetByte(_position++); } /// public override long Seek(long offset, SeekOrigin origin) { ThrowIfDisposed(); long position; switch (origin) { case SeekOrigin.Begin: position = offset; break; case SeekOrigin.Current: position = _position + offset; break; case SeekOrigin.End: position = _length + offset; break; default: throw new ArgumentException("Invalid origin.", "origin"); } if (position < 0) { throw new IOException("Attempted to seek before the beginning of the stream."); } if (position > int.MaxValue) { throw new IOException("Attempted to seek beyond the maximum value that can be represented using 32 bits."); } _position = (int)position; return position; } /// public override void SetLength(long value) { if (value < 0 || value > int.MaxValue) { throw new ArgumentOutOfRangeException("value"); } ThrowIfDisposed(); EnsureWriteable(); _buffer.EnsureCapacity((int)value); _length = (int)value; if (_position > _length) { _position = _length; } } /// public override void Write(byte[] buffer, int offset, int count) { if (buffer == null) { throw new ArgumentNullException("buffer"); } if (offset < 0 || offset > buffer.Length) { throw new ArgumentOutOfRangeException("offset"); } if (count < 0 || offset + count > buffer.Length) { throw new ArgumentOutOfRangeException("count"); } ThrowIfDisposed(); EnsureWriteable(); PrepareToWrite(count); _buffer.SetBytes(_position, buffer, offset, count); SetPositionAfterWrite(_position + count); } /// public override void WriteByte(byte value) { ThrowIfDisposed(); PrepareToWrite(1); _buffer.SetByte(_position, value); SetPositionAfterWrite(_position + 1); } // protected methods /// protected override void Dispose(bool disposing) { if (!_disposed) { if (_ownsBuffer) { _buffer.Dispose(); } _disposed = true; } base.Dispose(disposing); } // private methods private void EnsureWriteable() { if (!CanWrite) { throw new NotSupportedException("Stream is not writeable."); } } private int FindNullByte() { var position = _position; while (position < _length) { var segment = _buffer.AccessBackingBytes(position); var endOfSegmentIndex = segment.Offset + segment.Count; for (var index = segment.Offset; index < endOfSegmentIndex; index++) { if (segment.Array[index] == 0) { return position + (index - segment.Offset); } } position += segment.Count; } throw new EndOfStreamException(); } private void PrepareToWrite(int count) { var minimumCapacity = (long)_position + (long)count; if (minimumCapacity > int.MaxValue) { throw new IOException("Stream was too long."); } _buffer.EnsureCapacity((int)minimumCapacity); _buffer.Length = _buffer.Capacity; if (_length < _position) { _buffer.Clear(_length, _position - _length); } } private byte[] ReadBytes(int count) { ThrowIfEndOfStream(count); var bytes = new byte[count]; _buffer.GetBytes(_position, bytes, 0, count); _position += count; return bytes; } private void SetPositionAfterWrite(int position) { _position = position; if (_length < position) { _length = position; } } private void ThrowIfDisposed() { if (_disposed) { throw new ObjectDisposedException("ByteBufferStream"); } } private void ThrowIfEndOfStream(int count) { var minimumLength = (long)_position + (long)count; if (_length < minimumLength) { if (_position < _length) { _position = _length; } throw new EndOfStreamException(); } } /// public override string ReadCString(UTF8Encoding encoding) { if (encoding == null) { throw new ArgumentNullException("encoding"); } ThrowIfDisposed(); var bytes = ReadCStringBytes(); return Utf8Helper.DecodeUtf8String(bytes.Array, bytes.Offset, bytes.Count, encoding); } /// public override ArraySegment ReadCStringBytes() { ThrowIfDisposed(); ThrowIfEndOfStream(1); var segment = _buffer.AccessBackingBytes(_position); var index = Array.IndexOf(segment.Array, 0, segment.Offset, segment.Count); if (index != -1) { var length = index - segment.Offset; _position += length + 1; // advance over the null byte return new ArraySegment(segment.Array, segment.Offset, length); // without the null byte } else { var nullPosition = FindNullByte(); var length = nullPosition - _position; var cstring = ReadBytes(length + 1); // advance over the null byte return new ArraySegment(cstring, 0, length); // without the null byte } } /// public override Decimal128 ReadDecimal128() { ThrowIfDisposed(); ThrowIfEndOfStream(16); var lowBits = (ulong)ReadInt64(); var highBits = (ulong)ReadInt64(); return Decimal128.FromIEEEBits(highBits, lowBits); } /// public override double ReadDouble() { ThrowIfDisposed(); ThrowIfEndOfStream(8); var segment = _buffer.AccessBackingBytes(_position); if (segment.Count >= 8) { _position += 8; return BitConverter.ToDouble(segment.Array, segment.Offset); } else { this.ReadBytes(_temp, 0, 8); return BitConverter.ToDouble(_temp, 0); } } /// public override int ReadInt32() { ThrowIfDisposed(); ThrowIfEndOfStream(4); var segment = _buffer.AccessBackingBytes(_position); if (segment.Count >= 4) { _position += 4; var bytes = segment.Array; var offset = segment.Offset; return bytes[offset] | (bytes[offset + 1] << 8) | (bytes[offset + 2] << 16) | (bytes[offset + 3] << 24); } else { this.ReadBytes(_temp, 0, 4); return _temp[0] | (_temp[1] << 8) | (_temp[2] << 16) | (_temp[3] << 24); } } /// public override long ReadInt64() { ThrowIfDisposed(); ThrowIfEndOfStream(8); var segment = _buffer.AccessBackingBytes(_position); if (segment.Count >= 8) { _position += 8; return BitConverter.ToInt64(segment.Array, segment.Offset); } else { this.ReadBytes(_temp, 0, 8); return BitConverter.ToInt64(_temp, 0); } } /// public override ObjectId ReadObjectId() { ThrowIfDisposed(); ThrowIfEndOfStream(12); var segment = _buffer.AccessBackingBytes(_position); if (segment.Count >= 12) { _position += 12; return new ObjectId(segment.Array, segment.Offset); } else { this.ReadBytes(_temp, 0, 12); return new ObjectId(_temp, 0); } } /// public override IByteBuffer ReadSlice() { ThrowIfDisposed(); var position = _position; var length = ReadInt32(); ThrowIfEndOfStream(length - 4); Position = position + length; return _buffer.GetSlice(position, length); } /// public override string ReadString(UTF8Encoding encoding) { if (encoding == null) { throw new ArgumentNullException("encoding"); } ThrowIfDisposed(); var length = ReadInt32(); if (length <= 0) { var message = string.Format("Invalid string length: {0}.", length); throw new FormatException(message); } var segment = _buffer.AccessBackingBytes(_position); if (segment.Count >= length) { ThrowIfEndOfStream(length); if (segment.Array[segment.Offset + length - 1] != 0) { throw new FormatException("String is missing terminating null byte."); } _position += length; return Utf8Helper.DecodeUtf8String(segment.Array, segment.Offset, length - 1, encoding); } else { using var rentedBuffer = ThreadStaticBuffer.RentBuffer(length); var bytes = rentedBuffer.Bytes; this.ReadBytes(bytes, 0, length); if (bytes[length - 1] != 0) { throw new FormatException("String is missing terminating null byte."); } return Utf8Helper.DecodeUtf8String(bytes, 0, length - 1, encoding); } } /// public override void SkipCString() { ThrowIfDisposed(); var nullPosition = FindNullByte(); _position = nullPosition + 1; } /// public override void WriteCString(string value) { if (value == null) { throw new ArgumentNullException("value"); } ThrowIfDisposed(); var maxLength = CStringUtf8Encoding.GetMaxByteCount(value.Length) + 1; PrepareToWrite(maxLength); int actualLength; var segment = _buffer.AccessBackingBytes(_position); if (segment.Count >= maxLength) { actualLength = CStringUtf8Encoding.GetBytes(value, segment.Array, segment.Offset, Utf8Encodings.Strict); segment.Array[segment.Offset + actualLength] = 0; } else { // Compare to 128 to preserve original behavior const int maxLengthToUseCStringUtf8EncodingWith = 128; if (maxLength <= maxLengthToUseCStringUtf8EncodingWith) { using var rentedBuffer = ThreadStaticBuffer.RentBuffer(maxLengthToUseCStringUtf8EncodingWith); actualLength = CStringUtf8Encoding.GetBytes(value, rentedBuffer.Bytes, 0, Utf8Encodings.Strict); SetBytes(rentedBuffer.Bytes, actualLength); } else { using var rentedSegmentEncoded = Utf8Encodings.Strict.GetBytesUsingThreadStaticBuffer(value); var segmentEncoded = rentedSegmentEncoded.Segment; actualLength = segmentEncoded.Count; if (Array.IndexOf(segmentEncoded.Array, 0, 0, actualLength) != -1) { throw new ArgumentException("A CString cannot contain null bytes.", "value"); } SetBytes(segmentEncoded.Array, actualLength); } void SetBytes(byte[] bytes, int lenght) { _buffer.SetBytes(_position, bytes, 0, actualLength); _buffer.SetByte(_position + actualLength, 0); } } SetPositionAfterWrite(_position + actualLength + 1); } /// public override void WriteCStringBytes(byte[] value) { if (value == null) { throw new ArgumentNullException("value"); } ThrowIfDisposed(); var length = value.Length; PrepareToWrite(length + 1); _buffer.SetBytes(_position, value, 0, length); _buffer.SetByte(_position + length, 0); SetPositionAfterWrite(_position + length + 1); } /// public override void WriteDecimal128(Decimal128 value) { ThrowIfDisposed(); WriteInt64((long)value.GetIEEELowBits()); WriteInt64((long)value.GetIEEEHighBits()); } /// public override void WriteDouble(double value) { ThrowIfDisposed(); PrepareToWrite(8); var bytes = BitConverter.GetBytes(value); _buffer.SetBytes(_position, bytes, 0, 8); SetPositionAfterWrite(_position + 8); } /// public override void WriteInt32(int value) { ThrowIfDisposed(); PrepareToWrite(4); var segment = _buffer.AccessBackingBytes(_position); if (segment.Count >= 4) { segment.Array[segment.Offset] = (byte)value; segment.Array[segment.Offset + 1] = (byte)(value >> 8); segment.Array[segment.Offset + 2] = (byte)(value >> 16); segment.Array[segment.Offset + 3] = (byte)(value >> 24); } else { _temp[0] = (byte)(value); _temp[1] = (byte)(value >> 8); _temp[2] = (byte)(value >> 16); _temp[3] = (byte)(value >> 24); _buffer.SetBytes(_position, _temp, 0, 4); } SetPositionAfterWrite(_position + 4); } /// public override void WriteInt64(long value) { ThrowIfDisposed(); PrepareToWrite(8); var bytes = BitConverter.GetBytes(value); _buffer.SetBytes(_position, bytes, 0, 8); SetPositionAfterWrite(_position + 8); } /// public override void WriteObjectId(ObjectId value) { ThrowIfDisposed(); PrepareToWrite(12); var segment = _buffer.AccessBackingBytes(_position); if (segment.Count >= 12) { value.ToByteArray(segment.Array, segment.Offset); } else { var bytes = value.ToByteArray(); _buffer.SetBytes(_position, bytes, 0, 12); } SetPositionAfterWrite(_position + 12); } /// public override void WriteString(string value, UTF8Encoding encoding) { ThrowIfDisposed(); var maxLength = encoding.GetMaxByteCount(value.Length) + 5; PrepareToWrite(maxLength); int actualLength; var segment = _buffer.AccessBackingBytes(_position); if (segment.Count >= maxLength) { actualLength = encoding.GetBytes(value, 0, value.Length, segment.Array, segment.Offset + 4); var lengthPlusOne = actualLength + 1; segment.Array[segment.Offset] = (byte)lengthPlusOne; segment.Array[segment.Offset + 1] = (byte)(lengthPlusOne >> 8); segment.Array[segment.Offset + 2] = (byte)(lengthPlusOne >> 16); segment.Array[segment.Offset + 3] = (byte)(lengthPlusOne >> 24); segment.Array[segment.Offset + 4 + actualLength] = 0; } else { using var rentedSegmentEncoded = encoding.GetBytesUsingThreadStaticBuffer(value); var bytes = rentedSegmentEncoded.Segment.Array; actualLength = rentedSegmentEncoded.Segment.Count; var lengthPlusOneBytes = BitConverter.GetBytes(actualLength + 1); _buffer.SetBytes(_position, lengthPlusOneBytes, 0, 4); _buffer.SetBytes(_position + 4, bytes, 0, actualLength); _buffer.SetByte(_position + 4 + actualLength, 0); } SetPositionAfterWrite(_position + actualLength + 5); } } }