/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
module thrift.protocol.binary;

import std.array : uninitializedArray;
import std.typetuple : allSatisfy, TypeTuple;
import thrift.protocol.base;
import thrift.transport.base;
import thrift.internal.endian;

/**
 * TProtocol implementation of the Binary Thrift protocol.
 *
 * Multi-byte integers are passed through hostToNet()/netToHost() on their way
 * to/from the transport. Strings and binary blobs are encoded as an i32
 * length followed by the raw bytes.
 */
final class TBinaryProtocol(Transport = TTransport) if (
  isTTransport!Transport
) : TProtocol {

  /**
   * Constructs a new instance.
   *
   * Params:
   *   trans = The transport to use.
   *   containerSizeLimit = If positive, the container size is limited to the
   *     given number of items.
   *   stringSizeLimit = If positive, the string length is limited to the
   *     given number of bytes.
   *   strictRead = If false, old peers which do not include the protocol
   *     version are tolerated.
   *   strictWrite = Whether to include the protocol version in the header.
   */
  this(Transport trans, int containerSizeLimit = 0, int stringSizeLimit = 0,
    bool strictRead = false, bool strictWrite = true
  ) {
    trans_ = trans;
    this.containerSizeLimit = containerSizeLimit;
    this.stringSizeLimit = stringSizeLimit;
    this.strictRead = strictRead;
    this.strictWrite = strictWrite;
  }

  /// The transport this protocol reads from and writes to.
  Transport transport() @property {
    return trans_;
  }

  /// No-op: this class keeps no per-message state in its fields.
  void reset() {}

  /**
   * If false, old peers which do not include the protocol version in the
   * message header are tolerated.
   *
   * Defaults to false.
   */
  bool strictRead;

  /**
   * Whether to include the protocol version in the message header (older
   * versions didn't).
   *
   * Defaults to true.
   */
  bool strictWrite;

  /**
   * If positive, limits the number of items of deserialized containers to the
   * given amount.
   *
   * This is useful to avoid allocating excessive amounts of memory when broken
   * data is received. If the limit is exceeded, a SIZE_LIMIT-type
   * TProtocolException is thrown.
   *
   * Defaults to zero (no limit).
   */
  int containerSizeLimit;

  /**
   * If positive, limits the length of deserialized strings/binary data to the
   * given number of bytes. This includes the method name in message headers,
   * for both the versioned and the old (unversioned) header format.
   *
   * This is useful to avoid allocating excessive amounts of memory when broken
   * data is received. If the limit is exceeded, a SIZE_LIMIT-type
   * TProtocolException is thrown.
   *
   * Defaults to zero (no limit).
   */
  int stringSizeLimit;

  /*
   * Writing methods.
   */

  /// Writes a boolean as a single byte: 1 for true, 0 for false.
  void writeBool(bool b) {
    writeByte(b ? 1 : 0);
  }

  /// Writes a single byte.
  void writeByte(byte b) {
    trans_.write((cast(ubyte*)&b)[0 .. 1]);
  }

  /// Writes a 16-bit integer after converting it with hostToNet().
  void writeI16(short i16) {
    short net = hostToNet(i16);
    trans_.write((cast(ubyte*)&net)[0 .. 2]);
  }

  /// Writes a 32-bit integer after converting it with hostToNet().
  void writeI32(int i32) {
    int net = hostToNet(i32);
    trans_.write((cast(ubyte*)&net)[0 .. 4]);
  }

  /// Writes a 64-bit integer after converting it with hostToNet().
  void writeI64(long i64) {
    long net = hostToNet(i64);
    trans_.write((cast(ubyte*)&net)[0 .. 8]);
  }

  /**
   * Writes a double by reinterpreting its 8 bytes as a ulong and writing that
   * value like an integer (i.e. after hostToNet()).
   */
  void writeDouble(double dub) {
    static assert(double.sizeof == ulong.sizeof);
    auto bits = hostToNet(*cast(ulong*)(&dub));
    trans_.write((cast(ubyte*)&bits)[0 .. 8]);
  }

  /// Writes a string exactly like binary data (i32 length, then the bytes).
  void writeString(string str) {
    writeBinary(cast(ubyte[])str);
  }

  /**
   * Writes binary data as an i32 length followed by the raw bytes.
   *
   * The length must fit into an int; this is only checked via assert().
   */
  void writeBinary(ubyte[] buf) {
    assert(buf.length <= int.max);
    writeI32(cast(int)buf.length);
    trans_.write(buf);
  }

  /**
   * Writes a message header.
   *
   * With strictWrite, the layout is: i32 (VERSION_1 | type), name, i32 seqid.
   * Otherwise the old layout is used: name, type byte, i32 seqid.
   */
  void writeMessageBegin(TMessage message) {
    if (strictWrite) {
      int versn = VERSION_1 | message.type;
      writeI32(versn);
      writeString(message.name);
      writeI32(message.seqid);
    } else {
      writeString(message.name);
      writeByte(message.type);
      writeI32(message.seqid);
    }
  }
  void writeMessageEnd() {}

  /// Struct boundaries produce no output in this protocol.
  void writeStructBegin(TStruct tstruct) {}
  void writeStructEnd() {}

  /// Writes a field header: type byte followed by the i16 field id.
  void writeFieldBegin(TField field) {
    writeByte(field.type);
    writeI16(field.id);
  }
  void writeFieldEnd() {}

  /// Writes the TType.STOP marker byte.
  void writeFieldStop() {
    writeByte(TType.STOP);
  }

  /// Writes a list header: element type byte followed by the i32 size.
  void writeListBegin(TList list) {
    assert(list.size <= int.max);
    writeByte(list.elemType);
    writeI32(cast(int)list.size);
  }
  void writeListEnd() {}

  /// Writes a map header: key type byte, value type byte, i32 size.
  void writeMapBegin(TMap map) {
    assert(map.size <= int.max);
    writeByte(map.keyType);
    writeByte(map.valueType);
    writeI32(cast(int)map.size);
  }
  void writeMapEnd() {}

  /// Writes a set header: element type byte followed by the i32 size.
  void writeSetBegin(TSet set) {
    assert(set.size <= int.max);
    writeByte(set.elemType);
    writeI32(cast(int)set.size);
  }
  void writeSetEnd() {}


  /*
   * Reading methods.
   */

  /// Reads one byte and interprets any non-zero value as true.
  bool readBool() {
    return readByte() != 0;
  }

  /// Reads a single byte.
  byte readByte() {
    ubyte[1] b = void;
    trans_.readAll(b);
    return cast(byte)b[0];
  }

  /// Reads a 16-bit integer and converts it with netToHost().
  short readI16() {
    IntBuf!short b = void;
    trans_.readAll(b.bytes);
    return netToHost(b.value);
  }

  /// Reads a 32-bit integer and converts it with netToHost().
  int readI32() {
    IntBuf!int b = void;
    trans_.readAll(b.bytes);
    return netToHost(b.value);
  }

  /// Reads a 64-bit integer and converts it with netToHost().
  long readI64() {
    IntBuf!long b = void;
    trans_.readAll(b.bytes);
    return netToHost(b.value);
  }

  /**
   * Reads 8 bytes as a 64-bit integer, converts it with netToHost() and
   * reinterprets the resulting bits as a double.
   */
  double readDouble() {
    IntBuf!long b = void;
    trans_.readAll(b.bytes);
    b.value = netToHost(b.value);
    return *cast(double*)(&b.value);
  }

  /// Reads binary data (subject to stringSizeLimit) and returns it as string.
  string readString() {
    return cast(string)readBinary();
  }

  /**
   * Reads an i32 length followed by that many bytes.
   *
   * Throws: TProtocolException if the length is negative or exceeds a
   *   positive stringSizeLimit.
   */
  ubyte[] readBinary() {
    return readBinaryBody(readSize(stringSizeLimit));
  }

  /**
   * Reads a message header in either the versioned or the old format.
   *
   * The first i32 decides which: a negative value carries the version in its
   * upper 16 bits and the message type in its lowest byte; a non-negative
   * value is the method name length of an old-style, unversioned header.
   *
   * Throws: TProtocolException with BAD_VERSION if the version does not
   *   match, or if strictRead is set and the header is unversioned; with
   *   SIZE_LIMIT if the method name is longer than a positive
   *   stringSizeLimit.
   */
  TMessage readMessageBegin() {
    // Every field is assigned on each path that reaches the return below.
    TMessage msg = void;

    int size = readI32();
    if (size < 0) {
      int versn = size & VERSION_MASK;
      if (versn != VERSION_1) {
        throw new TProtocolException("Bad protocol version.",
          TProtocolException.Type.BAD_VERSION);
      }

      msg.type = cast(TMessageType)(size & MESSAGE_TYPE_MASK);
      msg.name = readString();
      msg.seqid = readI32();
    } else {
      if (strictRead) {
        throw new TProtocolException(
          "Protocol version missing, old client?",
          TProtocolException.Type.BAD_VERSION);
      }

      // Old-style header: `size` is the (non-negative) method name length.
      // Apply the same string size limit the versioned path gets through
      // readString(), so the limit cannot be bypassed by omitting the
      // version.
      msg.name = cast(string)readBinaryBody(checkSize(size, stringSizeLimit));
      msg.type = cast(TMessageType)(readByte());
      msg.seqid = readI32();
    }

    return msg;
  }
  void readMessageEnd() {}

  /// Struct boundaries are not encoded; returns a default TStruct.
  TStruct readStructBegin() {
    return TStruct();
  }
  void readStructEnd() {}

  /**
   * Reads a field header: the type byte and, unless it is TType.STOP, the i16
   * field id. The returned name is always null; for STOP the id is 0.
   */
  TField readFieldBegin() {
    TField f = void;
    f.name = null;
    // No id follows a STOP marker; zero it here so that the early return
    // below does not hand uninitialized memory to the caller.
    f.id = 0;
    f.type = cast(TType)readByte();
    if (f.type == TType.STOP) return f;
    f.id = readI16();
    return f;
  }
  void readFieldEnd() {}

  /**
   * Reads a list header: element type byte, then the i32 size.
   *
   * Throws: TProtocolException if the size is negative or exceeds a positive
   *   containerSizeLimit.
   */
  TList readListBegin() {
    // Read into locals so the order of the reads is explicit.
    auto elemType = cast(TType)readByte();
    auto size = readSize(containerSizeLimit);
    return TList(elemType, size);
  }
  void readListEnd() {}

  /**
   * Reads a map header: key type byte, value type byte, then the i32 size.
   *
   * Throws: TProtocolException if the size is negative or exceeds a positive
   *   containerSizeLimit.
   */
  TMap readMapBegin() {
    auto keyType = cast(TType)readByte();
    auto valueType = cast(TType)readByte();
    auto size = readSize(containerSizeLimit);
    return TMap(keyType, valueType, size);
  }
  void readMapEnd() {}

  /**
   * Reads a set header: element type byte, then the i32 size.
   *
   * Throws: TProtocolException if the size is negative or exceeds a positive
   *   containerSizeLimit.
   */
  TSet readSetBegin() {
    auto elemType = cast(TType)readByte();
    auto size = readSize(containerSizeLimit);
    return TSet(elemType, size);
  }
  void readSetEnd() {}

private:
  /// Reads exactly `size` bytes from the transport; returns null for size 0.
  ubyte[] readBinaryBody(int size) {
    if (size == 0) {
      return null;
    }

    // The buffer is completely overwritten by readAll(), so it does not need
    // to be initialized first.
    auto buf = uninitializedArray!(ubyte[])(size);
    trans_.readAll(buf);
    return buf;
  }

  /**
   * Validates a size value taken from the wire and returns it unchanged.
   *
   * Throws: TProtocolException with NEGATIVE_SIZE if size is negative, or
   *   with SIZE_LIMIT if limit is positive and size exceeds it.
   */
  int checkSize(int size, int limit) {
    if (size < 0) {
      throw new TProtocolException(TProtocolException.Type.NEGATIVE_SIZE);
    } else if (limit > 0 && size > limit) {
      throw new TProtocolException(TProtocolException.Type.SIZE_LIMIT);
    }
    return size;
  }

  /// Reads an i32 size from the transport and validates it with checkSize().
  int readSize(int limit) {
    return checkSize(readI32(), limit);
  }

  // Lowest byte of the first header word: the message type.
  enum MESSAGE_TYPE_MASK = 0x000000ff;
  // Upper 16 bits of the first header word: the protocol version.
  enum VERSION_MASK = 0xffff0000;
  // The only accepted version value. Its top bit is set, which is why a
  // versioned header's first i32 reads as negative.
  enum VERSION_1 = 0x80010000;

  Transport trans_;
}

/**
 * Convenience function creating a TBinaryProtocol for the given transport.
 *
 * It exists so that the transport type can be deduced from the argument via
 * IFTI instead of having to be spelled out at the constructor call (see
 * $(LINK2 http://d.puremagic.com/issues/show_bug.cgi?id=6082, D Bugzilla
 * enhancement request 6082)).
 *
 * All parameters are forwarded unchanged to the TBinaryProtocol constructor.
 */
TBinaryProtocol!Transport tBinaryProtocol(Transport)(Transport trans,
  int containerSizeLimit = 0, int stringSizeLimit = 0,
  bool strictRead = false, bool strictWrite = true
) if (isTTransport!Transport) {
  auto protocol = new TBinaryProtocol!Transport(
    trans,
    containerSizeLimit,
    stringSizeLimit,
    strictRead,
    strictWrite
  );
  return protocol;
}

unittest {
  import std.exception;
  import thrift.transport.memory;

  // Write a message header with default settings into a memory buffer and
  // compare the produced bytes against the expected layout.
  auto memory = new TMemoryBuffer;
  auto protocol = tBinaryProtocol(memory);
  protocol.writeMessageBegin(TMessage("foo", TMessageType.CALL, 0));

  immutable ubyte[] expected = [
    128, 1, 0, 1, // Version 1, TMessageType.CALL
    0, 0, 0, 3, // Method name length
    102, 111, 111, // Method name ("foo")
    0, 0, 0, 0, // Sequence id
  ];

  auto actual = new ubyte[expected.length];
  memory.readAll(actual);
  enforce(actual == expected);
}

unittest {
  import thrift.internal.test.protocol;

  // Run the shared size limit tests against the default instantiation
  // (Transport = TTransport).
  alias TBinaryProtocol!() Protocol;
  testContainerSizeLimit!(Protocol)();
  testStringSizeLimit!(Protocol)();
}

/**
 * TProtocolFactory creating a TBinaryProtocol instance for passed in
 * transports.
 *
 * The optional Transports template tuple parameter can be used to specify
 * one or more TTransport implementations to specifically instantiate
 * TBinaryProtocol for. If the actual transport types encountered at
 * runtime match one of the transports in the list, a specialized protocol
 * instance is created. Otherwise, a generic TTransport version is used.
 */
class TBinaryProtocolFactory(Transports...) if (
  allSatisfy!(isTTransport, Transports)
) : TProtocolFactory {
  /**
   * Constructs a new factory.
   *
   * The arguments are stored and passed on to the constructor of every
   * TBinaryProtocol created by getProtocol().
   *
   * Params:
   *   containerSizeLimit = Container size limit for created protocols.
   *   stringSizeLimit = String size limit for created protocols.
   *   strictRead = strictRead setting for created protocols.
   *   strictWrite = strictWrite setting for created protocols.
   */
  this (int containerSizeLimit = 0, int stringSizeLimit = 0,
    bool strictRead = false, bool strictWrite = true
  ) {
    containerSizeLimit_ = containerSizeLimit;
    stringSizeLimit_ = stringSizeLimit;
    strictRead_ = strictRead;
    strictWrite_ = strictWrite;
  }

  /**
   * Creates a TBinaryProtocol for the given transport.
   *
   * The listed Transports are tried in order, followed by TTransport itself;
   * the first type the transport can be cast to decides which
   * TBinaryProtocol instantiation is created.
   *
   * Throws: TProtocolException if none of the casts yields a non-null
   *   reference (the message reports this as a null transport).
   */
  TProtocol getProtocol(TTransport trans) const {
    // The tuple is expanded at compile time, producing one cast-and-check
    // per candidate type.
    foreach (Transport; TypeTuple!(Transports, TTransport)) {
      auto concreteTrans = cast(Transport)trans;
      if (concreteTrans) {
        return new TBinaryProtocol!Transport(concreteTrans,
          containerSizeLimit_, stringSizeLimit_, strictRead_, strictWrite_);
      }
    }
    throw new TProtocolException(
      "Passed null transport to TBinaryProtocolFactory.");
  }

protected:
  bool strictRead_;
  bool strictWrite_;
  int containerSizeLimit_;
  int stringSizeLimit_;
}