diff --git a/build_support/clang_format_exclusions.txt b/build_support/clang_format_exclusions.txt deleted file mode 100755 index e69de29bb..000000000 diff --git a/src/base_cmd.h b/src/base_cmd.h index 1b48c88eb..ce6909e16 100644 --- a/src/base_cmd.h +++ b/src/base_cmd.h @@ -34,6 +34,7 @@ const std::string kCmdNameStrlen = "strlen"; const std::string kCmdNameSetex = "setex"; const std::string kCmdNamePsetex = "psetex"; const std::string kCmdNameSetnx = "setnx"; +const std::string kCmdNameIncrbyfloat = "incrbyfloat"; const std::string kCmdNameGetBit = "getbit"; // multi diff --git a/src/cmd_kv.cc b/src/cmd_kv.cc index ed9e67b84..537bc898e 100644 --- a/src/cmd_kv.cc +++ b/src/cmd_kv.cc @@ -6,7 +6,8 @@ */ #include "cmd_kv.h" -#include "pstd/pstd_string.h" +#include "common.h" +#include "pstd_string.h" #include "pstd_util.h" #include "store.h" @@ -413,10 +414,10 @@ bool IncrbyCmd::DoInitial(PClient* client) { } void IncrbyCmd::DoCmd(PClient* client) { - int64_t new_value_ = 0; + int64_t new_value = 0; int64_t by_ = 0; pstd::String2int(client->argv_[2].data(), client->argv_[2].size(), &by_); - PError err = PSTORE.Incrby(client->Key(), by_, &new_value_); + PError err = PSTORE.Incrby(client->Key(), by_, &new_value); switch (err) { case PError_type: client->SetRes(CmdRes::kInvalidInt); @@ -427,7 +428,7 @@ void IncrbyCmd::DoCmd(PClient* client) { client->AppendInteger(by_); break; case PError_ok: - client->AppendInteger(new_value_); + client->AppendInteger(new_value); break; default: client->SetRes(CmdRes::kErrOther, "incrby cmd error"); @@ -435,6 +436,40 @@ void IncrbyCmd::DoCmd(PClient* client) { } } +IncrbyfloatCmd::IncrbyfloatCmd(const std::string& name, int16_t arity) + : BaseCmd(name, arity, CmdFlagsWrite, AclCategoryWrite | AclCategoryString) {} + +bool IncrbyfloatCmd::DoInitial(PClient* client) { + long double by_ = 0.00f; + if (StrToLongDouble(client->argv_[2].data(), client->argv_[2].size(), &by_)) { + client->SetRes(CmdRes::kInvalidFloat); + return false; + } + client->SetKey(client->argv_[1]); + return true; +} + +void IncrbyfloatCmd::DoCmd(PClient* client) { + std::string new_value; + PError err = PSTORE.Incrbyfloat(client->argv_[1], client->argv_[2], &new_value); + switch (err) { + case PError_type: + client->SetRes(CmdRes::kInvalidFloat); + break; + case PError_notExist: // key not exist, set a new value + PSTORE.ClearExpire(client->Key()); // clear key's old ttl + PSTORE.SetValue(client->Key(), PObject::CreateString(client->argv_[2])); + client->AppendString(client->argv_[2]); + break; + case PError_ok: + client->AppendString(new_value); + break; + default: + client->SetRes(CmdRes::kErrOther, "incrbyfloat cmd error"); + break; + } +} + SetnxCmd::SetnxCmd(const std::string& name, int16_t arity) : BaseCmd(name, arity, CmdFlagsWrite, AclCategoryWrite | AclCategoryString) {} @@ -499,4 +534,4 @@ void GetBitCmd::DoCmd(PClient* client) { return; } -} // namespace pikiwidb \ No newline at end of file +} // namespace pikiwidb diff --git a/src/cmd_kv.h b/src/cmd_kv.h index 4c4244b32..ed66d8faf 100644 --- a/src/cmd_kv.h +++ b/src/cmd_kv.h @@ -171,4 +171,15 @@ class GetBitCmd : public BaseCmd { void DoCmd(PClient *client) override; }; +class IncrbyfloatCmd : public BaseCmd { + public: + IncrbyfloatCmd(const std::string &name, int16_t arity); + + protected: + bool DoInitial(PClient *client) override; + + private: + void DoCmd(PClient *client) override; +}; + } // namespace pikiwidb diff --git a/src/cmd_table_manager.cc b/src/cmd_table_manager.cc index a2e64dce5..f4af3fe3c 100644 --- a/src/cmd_table_manager.cc +++ b/src/cmd_table_manager.cc @@ -57,6 +57,8 @@ void CmdTableManager::InitCmdTable() { cmds_->insert(std::make_pair(kCmdNameBitCount, std::move(bitcountPtr))); std::unique_ptr incrbyPtr = std::make_unique(kCmdNameIncrby, 3); cmds_->insert(std::make_pair(kCmdNameIncrby, std::move(incrbyPtr))); + std::unique_ptr incrbyfloatPtr = std::make_unique(kCmdNameIncrbyfloat, 3); + cmds_->insert(std::make_pair(kCmdNameIncrbyfloat, std::move(incrbyfloatPtr))); std::unique_ptr strlenPtr = std::make_unique(kCmdNameStrlen, 2); cmds_->insert(std::make_pair(kCmdNameStrlen, std::move(strlenPtr))); std::unique_ptr setexPtr = std::make_unique(kCmdNameSetex, 4); diff --git a/src/common.cc b/src/common.cc index 1f407aed0..c1cbdcb46 100644 --- a/src/common.cc +++ b/src/common.cc @@ -6,12 +6,15 @@ */ #include "common.h" +#include #include #include #include #include +#include #include #include +#include #include "unbounded_buffer.h" namespace pikiwidb { @@ -44,6 +47,50 @@ struct PErrorInfo g_errorInfo[] = { int Double2Str(char* ptr, std::size_t nBytes, double val) { return snprintf(ptr, nBytes - 1, "%.6g", val); } +int StrToLongDouble(const char* s, size_t slen, long double* ldval) { + char* pEnd; + std::string t(s, slen); + if (t.find(' ') != std::string::npos) { + return -1; + } + long double d = strtold(s, &pEnd); + if (pEnd != s + slen) { + return -1; + } + + if (ldval) { + *ldval = d; + } + return 0; +} + +int LongDoubleToStr(long double ldval, std::string* value) { + if (isnan(ldval)) { + return -1; + } else if (isinf(ldval)) { + if (ldval > 0) { + *value = "inf"; + } else { + *value = "-inf"; + } + return -1; + } else { + std::ostringstream oss; + oss << std::setprecision(15) << ldval; + *value = oss.str(); + + // Remove trailing zeroes after the '.' + size_t dotPos = value->find('.'); + if (dotPos != std::string::npos) { + value->erase(value->find_last_not_of('0') + 1, std::string::npos); + if (value->back() == '.') { + value->pop_back(); + } + } + return 0; + } +} + bool TryStr2Long(const char* ptr, size_t nBytes, long& val) { bool negtive = false; size_t i = 0; diff --git a/src/common.h b/src/common.h index 572b7831a..bc0329509 100644 --- a/src/common.h +++ b/src/common.h @@ -94,6 +94,7 @@ enum PError { PError_moduleinit = 16, PError_moduleuninit = 17, PError_modulerepeat = 18, + PError_overflow = 19, PError_max, }; @@ -145,6 +146,8 @@ inline std::size_t Number2Str(char* ptr, std::size_t nBytes, T val) { } int Double2Str(char* ptr, std::size_t nBytes, double val); +int StrToLongDouble(const char* s, size_t slen, long double* ldval); +int LongDoubleToStr(long double ldval, std::string* value); bool TryStr2Long(const char* ptr, std::size_t nBytes, long& val); // only for decimal bool Strtol(const char* ptr, std::size_t nBytes, long* outVal); bool Strtoll(const char* ptr, std::size_t nBytes, long long* outVal); diff --git a/src/store.cc b/src/store.cc index 7379aa902..a7edb78c1 100644 --- a/src/store.cc +++ b/src/store.cc @@ -9,12 +9,12 @@ #include #include #include "client.h" +#include "common.h" #include "config.h" #include "event_loop.h" #include "leveldb.h" #include "log.h" #include "multi.h" - namespace pikiwidb { uint32_t PObject::lruclock = static_cast(::time(nullptr)); @@ -589,6 +589,50 @@ PError PStore::Incrby(const PString& key, int64_t value, int64_t* ret) { return PError_ok; } +PError PStore::Incrbyfloat(const PString& key, std::string value, std::string* ret) { + PObject* old_value = nullptr; + long double old_number = 0.00f; + long double long_double_by = 0.00f; + auto db = &dbs_[dbno_]; + + if (StrToLongDouble(value.data(), value.size(), &long_double_by)) { + return PError_type; + } + + // shared when reading + std::unique_lock lock(mutex_); + PError err = getValueByType(key, old_value, PType_string); + if (err != PError_ok) { + return err; + } + + auto old_number_str = pikiwidb::GetDecodedString(old_value); + // old number to long double + if (StrToLongDouble(old_number_str->c_str(), old_number_str->size(), &old_number)) { + return PError_type; + } + + std::string total_string; + long double total = old_number + long_double_by; + if (LongDoubleToStr(total, &total_string)) { + return PError_overflow; + } + + *ret = total_string; + PObject new_value; + new_value = PObject::CreateString(total_string); + new_value.lru = PObject::lruclock; + auto [realObj, status] = db->insert_or_assign(key, std::move(new_value)); + const PObject& obj = realObj->second; + + // put this key to sync list + if (!waitSyncKeys_.empty()) { + waitSyncKeys_[dbno_].insert_or_assign(key, &obj); + } + + return PError_ok; +} + void PStore::SetExpire(const PString& key, uint64_t when) const { expiredDBs_[dbno_].SetExpire(key, when); } int64_t PStore::TTL(const PString& key, uint64_t now) { return expiredDBs_[dbno_].TTL(key, now); } diff --git a/src/store.h b/src/store.h index b16cb4cc3..fc2c8f9d4 100644 --- a/src/store.h +++ b/src/store.h @@ -121,6 +121,7 @@ class PStore { PObject* SetValue(const PString& key, PObject&& value); // incr PError Incrby(const PString& key, int64_t value, int64_t* ret); + PError Incrbyfloat(const PString& key, std::string value, std::string* ret); // for expire key enum ExpireResult : std::int8_t {