diff --git a/src/base_cmd.h b/src/base_cmd.h index 79874b3a2..9d4dfb666 100644 --- a/src/base_cmd.h +++ b/src/base_cmd.h @@ -37,6 +37,7 @@ const std::string kCmdNameIncrby = "incrby"; const std::string kCmdNameDecrby = "decrby"; const std::string kCmdNameIncrbyFloat = "incrbyfloat"; const std::string kCmdNameStrlen = "strlen"; +const std::string kCmdNameSetBit = "setbit"; const std::string kCmdNameSetEx = "setex"; const std::string kCmdNamePSetEx = "psetex"; const std::string kCmdNameBitOp = "bitop"; diff --git a/src/cmd_kv.cc b/src/cmd_kv.cc index 5bb64709f..eb6b1b31d 100644 --- a/src/cmd_kv.cc +++ b/src/cmd_kv.cc @@ -571,4 +571,67 @@ void GetBitCmd::DoCmd(PClient* client) { return; } +SetBitCmd::SetBitCmd(const std::string& name, int16_t arity) + : BaseCmd(name, arity, kCmdFlagsWrite, kAclCategoryWrite | kAclCategoryString) {} + +bool SetBitCmd::DoInitial(PClient* client) { + client->SetKey(client->argv_[1]); + return true; +} + +void SetBitCmd::DoCmd(PClient* client) { + PObject* value = nullptr; + PError err = PSTORE.GetValueByType(client->Key(), value, kPTypeString); + if (err == kPErrorNotExist) { + value = PSTORE.SetValue(client->Key(), PObject::CreateString("")); + err = kPErrorOK; + } + + if (err != kPErrorOK) { + client->AppendInteger(0); + return; + } + + long offset = 0; + long on = 0; + if (!Strtol(client->argv_[2].c_str(), client->argv_[2].size(), &offset) || + !Strtol(client->argv_[3].c_str(), client->argv_[3].size(), &on)) { + client->SetRes(CmdRes::kInvalidInt); + return; + } + + if (offset < 0 || offset > kStringMaxBytes) { + client->AppendInteger(0); + return; + } + + PString* pStringPtr = value->CastString(); + if (!pStringPtr) { + client->AppendInteger(0); + return; + } + + PString& newVal = *pStringPtr; + + size_t bytes = offset / 8; + size_t bits = offset % 8; + + if (bytes + 1 > newVal.size()) { + newVal.resize(bytes + 1, '\0'); + } + + const char oldByte = newVal[bytes]; + char& byte = newVal[bytes]; + if (on) { + byte |= (0x1 << bits); + } else { + byte &= ~(0x1 << bits); + } + + value->Reset(new PString(newVal)); + value->encoding = kPEncodeRaw; + client->AppendInteger((oldByte & (0x1 << bits)) ? 1 : 0); + return; +} + } // namespace pikiwidb diff --git a/src/cmd_kv.h b/src/cmd_kv.h index b94c53ed4..5540b5be3 100644 --- a/src/cmd_kv.h +++ b/src/cmd_kv.h @@ -170,6 +170,17 @@ class DecrbyCmd : public BaseCmd { void DoCmd(PClient *client) override; }; +class SetBitCmd : public BaseCmd { + public: + SetBitCmd(const std::string &name, int16_t arity); + + protected: + bool DoInitial(PClient *client) override; + + private: + void DoCmd(PClient *client) override; +}; + class GetBitCmd : public BaseCmd { public: GetBitCmd(const std::string &name, int16_t arity); diff --git a/src/cmd_table_manager.cc b/src/cmd_table_manager.cc index 8556a8f5b..29a39d167 100644 --- a/src/cmd_table_manager.cc +++ b/src/cmd_table_manager.cc @@ -58,6 +58,7 @@ void CmdTableManager::InitCmdTable() { ADD_COMMAND(BitOp, -4); ADD_COMMAND(BitCount, -2); ADD_COMMAND(GetBit, 3); + ADD_COMMAND(SetBit, 4); // hash ADD_COMMAND(HSet, -4);