00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 #ifndef FST_LIB_CACHE_H__
00023 #define FST_LIB_CACHE_H__
00024
00025 #include <vector>
00026 using std::vector;
00027 #include <list>
00028
00029 #include <fst/vector-fst.h>
00030
00031 DECLARE_bool(fst_default_cache_gc);
00032 DECLARE_int64(fst_default_cache_gc_limit);
00033
00034 namespace fst {
00035
00036 struct CacheOptions {
00037 bool gc;
00038 size_t gc_limit;
00039
00040
00041 CacheOptions(bool g, size_t l) : gc(g), gc_limit(l) {}
00042 CacheOptions()
00043 : gc(FLAGS_fst_default_cache_gc),
00044 gc_limit(FLAGS_fst_default_cache_gc_limit) {}
00045 };
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062 template <class S>
00063 class CacheBaseImpl : public VectorFstBaseImpl<S> {
00064 public:
00065 using FstImpl<typename S::Arc>::Type;
00066 using VectorFstBaseImpl<S>::NumStates;
00067 using VectorFstBaseImpl<S>::AddState;
00068 using VectorFstBaseImpl<S>::SetState;
00069
00070 typedef S State;
00071 typedef typename S::Arc Arc;
00072 typedef typename Arc::Weight Weight;
00073 typedef typename Arc::StateId StateId;
00074
00075 CacheBaseImpl()
00076 : cache_start_(false), nknown_states_(0), min_unexpanded_state_id_(0),
00077 cache_first_state_id_(kNoStateId), cache_first_state_(0),
00078 cache_gc_(FLAGS_fst_default_cache_gc), cache_size_(0),
00079 cache_limit_(FLAGS_fst_default_cache_gc_limit > kMinCacheLimit ||
00080 FLAGS_fst_default_cache_gc_limit == 0 ?
00081 FLAGS_fst_default_cache_gc_limit : kMinCacheLimit) {}
00082
00083 explicit CacheBaseImpl(const CacheOptions &opts)
00084 : cache_start_(false), nknown_states_(0),
00085 min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId),
00086 cache_first_state_(0), cache_gc_(opts.gc), cache_size_(0),
00087 cache_limit_(opts.gc_limit > kMinCacheLimit || opts.gc_limit == 0 ?
00088 opts.gc_limit : kMinCacheLimit) {}
00089
00090
00091 CacheBaseImpl(const CacheBaseImpl &impl)
00092 : cache_start_(false), nknown_states_(0),
00093 min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId),
00094 cache_first_state_(0), cache_gc_(impl.cache_gc_), cache_size_(0),
00095 cache_limit_(impl.cache_limit_) {}
00096
00097 ~CacheBaseImpl() {
00098 delete cache_first_state_;
00099 }
00100
00101
00102 const S *GetState(StateId s) const {
00103 if (s == cache_first_state_id_)
00104 return cache_first_state_;
00105 else
00106 return VectorFstBaseImpl<S>::GetState(s);
00107 }
00108
00109
00110 S *GetState(StateId s) {
00111 if (s == cache_first_state_id_)
00112 return cache_first_state_;
00113 else
00114 return VectorFstBaseImpl<S>::GetState(s);
00115 }
00116
00117
00118 const S *CheckState(StateId s) const {
00119 if (s == cache_first_state_id_)
00120 return cache_first_state_;
00121 else if (s < NumStates())
00122 return VectorFstBaseImpl<S>::GetState(s);
00123 else
00124 return 0;
00125 }
00126
00127
00128 S *ExtendState(StateId s) {
00129 if (s == cache_first_state_id_) {
00130 return cache_first_state_;
00131 } else if (cache_limit_ == 0 && cache_first_state_id_ == kNoStateId) {
00132 cache_first_state_id_ = s;
00133 cache_first_state_ = new S;
00134 return cache_first_state_;
00135 } else if (cache_first_state_id_ != kNoStateId &&
00136 cache_first_state_->ref_count == 0) {
00137 cache_first_state_id_ = s;
00138 cache_first_state_->Reset();
00139 return cache_first_state_;
00140 } else {
00141 while (NumStates() <= s)
00142 AddState(0);
00143 if (!VectorFstBaseImpl<S>::GetState(s)) {
00144 SetState(s, new S);
00145 if (cache_first_state_id_ != kNoStateId) {
00146 while (NumStates() <= cache_first_state_id_)
00147 AddState(0);
00148 SetState(cache_first_state_id_, cache_first_state_);
00149 if (cache_gc_) {
00150 cache_states_.push_back(cache_first_state_id_);
00151 cache_size_ += sizeof(S) +
00152 cache_first_state_->arcs.capacity() * sizeof(Arc);
00153 }
00154 cache_limit_ = kMinCacheLimit;
00155 cache_first_state_id_ = kNoStateId;
00156 cache_first_state_ = 0;
00157 }
00158 if (cache_gc_) {
00159 cache_states_.push_back(s);
00160 cache_size_ += sizeof(S);
00161 if (cache_size_ > cache_limit_)
00162 GC(s, false);
00163 }
00164 }
00165 return VectorFstBaseImpl<S>::GetState(s);
00166 }
00167 }
00168
00169 void SetStart(StateId s) {
00170 VectorFstBaseImpl<S>::SetStart(s);
00171 cache_start_ = true;
00172 if (s >= nknown_states_)
00173 nknown_states_ = s + 1;
00174 }
00175
00176 void SetFinal(StateId s, Weight w) {
00177 S *state = ExtendState(s);
00178 state->final = w;
00179 state->flags |= kCacheFinal | kCacheRecent;
00180 }
00181
00182 void AddArc(StateId s, const Arc &arc) {
00183 S *state = ExtendState(s);
00184 state->arcs.push_back(arc);
00185 }
00186
00187
00188 void SetArcs(StateId s) {
00189 S *state = ExtendState(s);
00190 vector<Arc> &arcs = state->arcs;
00191 state->niepsilons = state->noepsilons = 0;
00192 for (size_t a = 0; a < arcs.size(); ++a) {
00193 const Arc &arc = arcs[a];
00194 if (arc.nextstate >= nknown_states_)
00195 nknown_states_ = arc.nextstate + 1;
00196 if (arc.ilabel == 0)
00197 ++state->niepsilons;
00198 if (arc.olabel == 0)
00199 ++state->noepsilons;
00200 }
00201 ExpandedState(s);
00202 state->flags |= kCacheArcs | kCacheRecent;
00203 if (cache_gc_ && s != cache_first_state_id_) {
00204 cache_size_ += arcs.capacity() * sizeof(Arc);
00205 if (cache_size_ > cache_limit_)
00206 GC(s, false);
00207 }
00208 };
00209
00210 void ReserveArcs(StateId s, size_t n) {
00211 S *state = ExtendState(s);
00212 state->arcs.reserve(n);
00213 }
00214
00215
00216 bool HasStart() const { return cache_start_; }
00217
00218
00219 bool HasFinal(StateId s) const {
00220 const S *state = CheckState(s);
00221 if (state && state->flags & kCacheFinal) {
00222 state->flags |= kCacheRecent;
00223 return true;
00224 } else {
00225 return false;
00226 }
00227 }
00228
00229
00230 bool HasArcs(StateId s) const {
00231 const S *state = CheckState(s);
00232 if (state && state->flags & kCacheArcs) {
00233 state->flags |= kCacheRecent;
00234 return true;
00235 } else {
00236 return false;
00237 }
00238 }
00239
00240 Weight Final(StateId s) const {
00241 const S *state = GetState(s);
00242 return state->final;
00243 }
00244
00245 size_t NumArcs(StateId s) const {
00246 const S *state = GetState(s);
00247 return state->arcs.size();
00248 }
00249
00250 size_t NumInputEpsilons(StateId s) const {
00251 const S *state = GetState(s);
00252 return state->niepsilons;
00253 }
00254
00255 size_t NumOutputEpsilons(StateId s) const {
00256 const S *state = GetState(s);
00257 return state->noepsilons;
00258 }
00259
00260
00261 void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const {
00262 const S *state = GetState(s);
00263 data->base = 0;
00264 data->narcs = state->arcs.size();
00265 data->arcs = data->narcs > 0 ? &(state->arcs[0]) : 0;
00266 data->ref_count = &(state->ref_count);
00267 ++(*data->ref_count);
00268 }
00269
00270
00271 StateId NumKnownStates() const { return nknown_states_; }
00272
00273
00274 void UpdateNumKnownStates(StateId s) {
00275 if (s >= nknown_states_)
00276 nknown_states_ = s + 1;
00277 }
00278
00279
00280 StateId MinUnexpandedState() const {
00281 while (min_unexpanded_state_id_ < expanded_states_.size() &&
00282 expanded_states_[min_unexpanded_state_id_])
00283 ++min_unexpanded_state_id_;
00284 return min_unexpanded_state_id_;
00285 }
00286
00287
00288
00289
00290
00291
00292 void GC(StateId current, bool free_recent) {
00293 if (!cache_gc_)
00294 return;
00295 VLOG(2) << "CacheImpl: Enter GC: object = " << Type() << "(" << this
00296 << "), free recently cached = " << free_recent
00297 << ", cache size = " << cache_size_
00298 << ", cache limit = " << cache_limit_ << "\n";
00299 typename list<StateId>::iterator siter = cache_states_.begin();
00300
00301 size_t cache_target = (2 * cache_limit_)/3 + 1;
00302 while (siter != cache_states_.end()) {
00303 StateId s = *siter;
00304 S* state = VectorFstBaseImpl<S>::GetState(s);
00305 if (cache_size_ > cache_target && state->ref_count == 0 &&
00306 (free_recent || !(state->flags & kCacheRecent)) && s != current) {
00307 cache_size_ -= sizeof(S) + state->arcs.capacity() * sizeof(Arc);
00308 delete state;
00309 SetState(s, 0);
00310 cache_states_.erase(siter++);
00311 } else {
00312 state->flags &= ~kCacheRecent;
00313 ++siter;
00314 }
00315 }
00316 if (!free_recent && cache_size_ > cache_target) {
00317 GC(current, true);
00318 } else {
00319 while (cache_size_ > cache_target) {
00320 cache_limit_ *= 2;
00321 cache_target *= 2;
00322 }
00323 }
00324 VLOG(2) << "CacheImpl: Exit GC: object = " << Type() << "(" << this
00325 << "), free recently cached = " << free_recent
00326 << ", cache size = " << cache_size_
00327 << ", cache limit = " << cache_limit_ << "\n";
00328 }
00329
00330 void ExpandedState(StateId s) {
00331 if (s < min_unexpanded_state_id_)
00332 return;
00333 while (expanded_states_.size() <= s)
00334 expanded_states_.push_back(false);
00335 expanded_states_[s] = true;
00336 }
00337
00338
00339 bool GetCacheGc() const { return cache_gc_; }
00340 size_t GetCacheLimit() const { return cache_limit_; }
00341 size_t GetCacheSize() const { return cache_size_; }
00342
00343 private:
00344 static const uint32 kCacheFinal = 0x0001;
00345 static const uint32 kCacheArcs = 0x0002;
00346 static const uint32 kCacheRecent = 0x0004;
00347
00348 static const size_t kMinCacheLimit = 8096;
00349
00350 bool cache_start_;
00351 StateId nknown_states_;
00352 vector<bool> expanded_states_;
00353 mutable StateId min_unexpanded_state_id_;
00354 StateId cache_first_state_id_;
00355 S *cache_first_state_;
00356 list<StateId> cache_states_;
00357 bool cache_gc_;
00358 size_t cache_size_;
00359 size_t cache_limit_;
00360
00361 void InitStateIterator(StateIteratorData<Arc> *);
00362 void operator=(const CacheBaseImpl<Arc> &impl);
00363 };
00364
00365 template <class S> const uint32 CacheBaseImpl<S>::kCacheFinal;
00366 template <class S> const uint32 CacheBaseImpl<S>::kCacheArcs;
00367 template <class S> const uint32 CacheBaseImpl<S>::kCacheRecent;
00368 template <class S> const size_t CacheBaseImpl<S>::kMinCacheLimit;
00369
00370
00371
// A cached FST state for arc type A: its final weight, its arcs, the
// epsilon counts derived from those arcs, and bookkeeping for the cache
// (status flags and a reference count).
template <class A>
struct CacheState {
  typedef A Arc;
  typedef typename A::Weight Weight;
  typedef typename A::StateId StateId;

  // Creates an empty state: Weight::Zero() as final weight, no arcs, zero
  // epsilon counts, no flags and no references. The epsilon counts are
  // initialized explicitly so that they never hold indeterminate values.
  CacheState()
      : final(Weight::Zero()), niepsilons(0), noepsilons(0),
        flags(0), ref_count(0) {}

  // Returns the state to its freshly constructed condition so the object
  // can be reused for a different state id. The arc vector is cleared, not
  // replaced, so its allocated storage can be reused.
  void Reset() {
    final = Weight::Zero();
    arcs.clear();
    niepsilons = 0;
    noepsilons = 0;
    flags = 0;
    ref_count = 0;
  }

  Weight final;            // Final weight.
  std::vector<A> arcs;     // Arcs leaving this state.
  size_t niepsilons;       // Number of arcs with ilabel == 0.
  size_t noepsilons;       // Number of arcs with olabel == 0.
  mutable uint32 flags;    // Cache status bits; writable on const states.
  mutable int ref_count;   // Outstanding references; writable on const states.
};
00393
00394
00395 template <class A>
00396 class CacheImpl : public CacheBaseImpl< CacheState<A> > {
00397 public:
00398 typedef CacheState<A> State;
00399
00400 CacheImpl() {}
00401
00402 explicit CacheImpl(const CacheOptions &opts)
00403 : CacheBaseImpl< CacheState<A> >(opts) {}
00404
00405 CacheImpl(const CacheImpl<State> &impl) : CacheBaseImpl<State>(impl) {}
00406
00407 private:
00408 void operator=(const CacheImpl<State> &impl);
00409 };
00410
00411
00412
00413
00414
00415
00416 template <class F>
00417 class CacheStateIterator : public StateIteratorBase<typename F::Arc> {
00418 public:
00419 typedef typename F::Arc Arc;
00420 typedef typename Arc::StateId StateId;
00421 typedef typename F::State State;
00422 typedef CacheBaseImpl<State> Impl;
00423
00424 CacheStateIterator(const F &fst, Impl *impl)
00425 : fst_(fst), impl_(impl), s_(0) {}
00426
00427
00428
00429
00430 explicit CacheStateIterator(const F &fst)
00431 : fst_(fst), impl_(fst.impl_), s_(0) {}
00432
00433 bool Done() const {
00434 if (s_ < impl_->NumKnownStates())
00435 return false;
00436 fst_.Start();
00437 if (s_ < impl_->NumKnownStates())
00438 return false;
00439 for (StateId u = impl_->MinUnexpandedState();
00440 u < impl_->NumKnownStates();
00441 u = impl_->MinUnexpandedState()) {
00442
00443 ArcIterator<F> aiter(fst_, u);
00444 aiter.SetFlags(kArcValueFlags, kArcValueFlags | kArcNoCache);
00445 for (; !aiter.Done(); aiter.Next())
00446 impl_->UpdateNumKnownStates(aiter.Value().nextstate);
00447 impl_->ExpandedState(u);
00448 if (s_ < impl_->NumKnownStates())
00449 return false;
00450 }
00451 return true;
00452 }
00453
00454 StateId Value() const { return s_; }
00455
00456 void Next() { ++s_; }
00457
00458 void Reset() { s_ = 0; }
00459
00460 private:
00461
00462
00463
00464 virtual bool Done_() const { return Done(); }
00465 virtual StateId Value_() const { return Value(); }
00466 virtual void Next_() { Next(); }
00467 virtual void Reset_() { Reset(); }
00468
00469 const F &fst_;
00470 Impl *impl_;
00471 StateId s_;
00472 };
00473
00474
00475
00476
00477 template <class F>
00478 class CacheArcIterator {
00479 public:
00480 typedef typename F::Arc Arc;
00481 typedef typename F::State State;
00482 typedef typename Arc::StateId StateId;
00483 typedef CacheBaseImpl<State> Impl;
00484
00485 CacheArcIterator(Impl *impl, StateId s) : i_(0) {
00486 state_ = impl->ExtendState(s);
00487 ++state_->ref_count;
00488 }
00489
00490
00491
00492 CacheArcIterator(const F &fst, StateId s) : i_(0) {
00493 state_ = fst.impl_->ExtendState(s);
00494 ++state_->ref_count;
00495 }
00496
00497 ~CacheArcIterator() { --state_->ref_count; }
00498
00499 bool Done() const { return i_ >= state_->arcs.size(); }
00500
00501 const Arc& Value() const { return state_->arcs[i_]; }
00502
00503 void Next() { ++i_; }
00504
00505 size_t Position() const { return i_; }
00506
00507 void Reset() { i_ = 0; }
00508
00509 void Seek(size_t a) { i_ = a; }
00510
00511 uint32 Flags() const {
00512 return kArcValueFlags;
00513 }
00514
00515 void SetFlags(uint32 flags, uint32 mask) {}
00516
00517 private:
00518 const State *state_;
00519 size_t i_;
00520
00521 DISALLOW_COPY_AND_ASSIGN(CacheArcIterator);
00522 };
00523
00524 }
00525
00526 #endif /// FST_LIB_CACHE_H__
00527