00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021 #ifndef FST_LIB_ACCUMULATOR_H__
00022 #define FST_LIB_ACCUMULATOR_H__
00023
00024 #include <algorithm>
00025 #include <tr1/unordered_map>
00026 using std::tr1::unordered_map;
00027 using std::tr1::unordered_multimap;
00028 #include <vector>
00029 using std::vector;
00030
00031 #include <fst/arcfilter.h>
00032 #include <fst/arcsort.h>
00033 #include <fst/dfs-visit.h>
00034 #include <fst/expanded-fst.h>
00035 #include <fst/replace.h>
00036
00037 namespace fst {
00038
00039
00040 template <class A>
00041 class DefaultAccumulator {
00042 public:
00043 typedef A Arc;
00044 typedef typename A::StateId StateId;
00045 typedef typename A::Weight Weight;
00046
00047 DefaultAccumulator() {}
00048
00049 DefaultAccumulator(const DefaultAccumulator<A> &acc) {}
00050
00051 void Init(const Fst<A>& fst, bool copy = false) {}
00052
00053 void SetState(StateId) {}
00054
00055 Weight Sum(Weight w, Weight v) {
00056 return Plus(w, v);
00057 }
00058
00059 template <class ArcIterator>
00060 Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin,
00061 ssize_t end) {
00062 Weight sum = w;
00063 aiter->Seek(begin);
00064 for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos)
00065 sum = Plus(sum, aiter->Value().weight);
00066 return sum;
00067 }
00068
00069 private:
00070 void operator=(const DefaultAccumulator<A> &);
00071 };
00072
00073
00074
00075
00076
00077 template <class A>
00078 class LogAccumulator {
00079 public:
00080 typedef A Arc;
00081 typedef typename A::StateId StateId;
00082 typedef typename A::Weight Weight;
00083
00084 LogAccumulator() {}
00085
00086 LogAccumulator(const LogAccumulator<A> &acc) {}
00087
00088 void Init(const Fst<A>& fst, bool copy = false) {}
00089
00090 void SetState(StateId) {}
00091
00092 Weight Sum(Weight w, Weight v) {
00093 return LogPlus(w, v);
00094 }
00095
00096 template <class ArcIterator>
00097 Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin,
00098 ssize_t end) {
00099 Weight sum = w;
00100 aiter->Seek(begin);
00101 for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos)
00102 sum = LogPlus(sum, aiter->Value().weight);
00103 return sum;
00104 }
00105
00106 private:
00107 double LogPosExp(double x) { return log(1.0F + exp(-x)); }
00108
00109 Weight LogPlus(Weight w, Weight v) {
00110 float f1 = w.Value();
00111 float f2 = v.Value();
00112 if (f1 > f2)
00113 return Weight(f2 - LogPosExp(f1 - f2));
00114 else
00115 return Weight(f1 - LogPosExp(f2 - f1));
00116 }
00117
00118 void operator=(const LogAccumulator<A> &);
00119 };
00120
00121
00122 class FastLogAccumulatorData {
00123 public:
00124 FastLogAccumulatorData() {}
00125
00126 vector<double> *Weights() { return &weights_; }
00127 vector<ssize_t> *WeightPositions() { return &weight_positions_; }
00128 double *WeightEnd() { return &(weights_[weights_.size() - 1]); };
00129 int RefCount() const { return ref_count_.count(); }
00130 int IncrRefCount() { return ref_count_.Incr(); }
00131 int DecrRefCount() { return ref_count_.Decr(); }
00132
00133 private:
00134
00135
00136
00137 vector<double> weights_;
00138
00139
00140
00141 vector<ssize_t> weight_positions_;
00142 RefCounter ref_count_;
00143
00144 DISALLOW_COPY_AND_ASSIGN(FastLogAccumulatorData);
00145 };
00146
00147
00148
00149
00150
00151
00152 template <class A>
00153 class FastLogAccumulator {
00154 public:
00155 typedef A Arc;
00156 typedef typename A::StateId StateId;
00157 typedef typename A::Weight Weight;
00158
00159 explicit FastLogAccumulator(ssize_t arc_limit = 20, ssize_t arc_period = 10)
00160 : arc_limit_(arc_limit),
00161 arc_period_(arc_period),
00162 data_(new FastLogAccumulatorData()) {
00163 }
00164
00165 FastLogAccumulator(const FastLogAccumulator<A> &acc)
00166 : arc_limit_(acc.arc_limit_),
00167 arc_period_(acc.arc_period_),
00168 data_(acc.data_) {
00169 data_->IncrRefCount();
00170 }
00171
00172 ~FastLogAccumulator() {
00173 if (!data_->DecrRefCount())
00174 delete data_;
00175 }
00176
00177 void SetState(StateId s) {
00178 vector<double> &weights = *data_->Weights();
00179 vector<ssize_t> &weight_positions = *data_->WeightPositions();
00180
00181 CHECK(weight_positions.size() > s);
00182
00183 ssize_t pos = weight_positions[s];
00184 if (pos >= 0)
00185 state_weights_ = &(weights[pos]);
00186 else
00187 state_weights_ = 0;
00188 }
00189
00190 Weight Sum(Weight w, Weight v) {
00191 return LogPlus(w, v);
00192 }
00193
00194 template <class ArcIterator>
00195 Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin,
00196 ssize_t end) {
00197 Weight sum = w;
00198
00199 ssize_t index_begin = -1, index_end = -1;
00200 ssize_t stored_begin = end, stored_end = end;
00201 if (state_weights_ != 0) {
00202 index_begin = begin > 0 ? (begin - 1)/ arc_period_ + 1 : 0;
00203 index_end = end / arc_period_;
00204 stored_begin = index_begin * arc_period_;
00205 stored_end = index_end * arc_period_;
00206 }
00207
00208 if (begin < stored_begin) {
00209 ssize_t pos_end = min(stored_begin, end);
00210 aiter->Seek(begin);
00211 for (ssize_t pos = begin; pos < pos_end; aiter->Next(), ++pos)
00212 sum = LogPlus(sum, aiter->Value().weight);
00213 }
00214
00215 if (stored_begin < stored_end) {
00216 sum = LogPlus(sum, LogMinus(state_weights_[index_end],
00217 state_weights_[index_begin]));
00218 }
00219
00220 if (stored_end < end) {
00221 ssize_t pos_start = max(stored_begin, stored_end);
00222 aiter->Seek(pos_start);
00223 for (ssize_t pos = pos_start; pos < end; aiter->Next(), ++pos)
00224 sum = LogPlus(sum, aiter->Value().weight);
00225 }
00226 return sum;
00227 }
00228
00229 template <class F>
00230 void Init(const F &fst, bool copy = false) {
00231 if (copy)
00232 return;
00233 vector<double> &weights = *data_->Weights();
00234 vector<ssize_t> &weight_positions = *data_->WeightPositions();
00235 CHECK(weights.empty());
00236 CHECK_GE(arc_limit_, arc_period_);
00237 weight_positions.reserve(CountStates(fst));
00238
00239 ssize_t weight_position = 0;
00240 for(StateIterator<F> siter(fst); !siter.Done(); siter.Next()) {
00241 StateId s = siter.Value();
00242 if (fst.NumArcs(s) >= arc_limit_) {
00243 double sum = FloatLimits<double>::kPosInfinity;
00244 weight_positions.push_back(weight_position);
00245 weights.push_back(sum);
00246 ++weight_position;
00247 ssize_t narcs = 0;
00248 for(ArcIterator<F> aiter(fst, s); !aiter.Done(); aiter.Next()) {
00249 const A &arc = aiter.Value();
00250 sum = LogPlus(sum, arc.weight);
00251
00252 if (++narcs % arc_period_ == 0) {
00253 weights.push_back(sum);
00254 ++weight_position;
00255 }
00256 }
00257 } else {
00258 weight_positions.push_back(-1);
00259 }
00260 }
00261 }
00262
00263 private:
00264 double LogPosExp(double x) {
00265 return x == FloatLimits<double>::kPosInfinity ?
00266 0.0 : log(1.0F + exp(-x));
00267 }
00268
00269 double LogMinusExp(double x) {
00270 return x == FloatLimits<double>::kPosInfinity ?
00271 0.0 : log(1.0F - exp(-x));
00272 }
00273
00274 Weight LogPlus(Weight w, Weight v) {
00275 float f1 = w.Value();
00276 float f2 = v.Value();
00277 if (f1 > f2)
00278 return Weight(f2 - LogPosExp(f1 - f2));
00279 else
00280 return Weight(f1 - LogPosExp(f2 - f1));
00281 }
00282
00283 double LogPlus(double f1, Weight v) {
00284 float f2 = v.Value();
00285 if (f1 == FloatLimits<double>::kPosInfinity)
00286 return f2;
00287 else if (f1 > f2)
00288 return f2 - LogPosExp(f1 - f2);
00289 else
00290 return f1 - LogPosExp(f2 - f1);
00291 }
00292
00293 Weight LogMinus(double f1, double f2) {
00294 CHECK_LT(f1, f2);
00295 if (f2 == FloatLimits<double>::kPosInfinity)
00296 return f1;
00297 else
00298 return Weight(f1 - LogMinusExp(f2 - f1));
00299 }
00300 ssize_t arc_limit_;
00301 ssize_t arc_period_;
00302 bool init_;
00303 FastLogAccumulatorData *data_;
00304 double *state_weights_;
00305
00306 void operator=(const FastLogAccumulator<A> &);
00307 };
00308
00309
00310
00311
00312 template <class A>
00313 class CacheLogAccumulatorData {
00314 public:
00315 typedef A Arc;
00316 typedef typename A::StateId StateId;
00317 typedef typename A::Weight Weight;
00318
00319 CacheLogAccumulatorData(bool gc, size_t gc_limit)
00320 : cache_gc_(gc), cache_limit_(gc_limit), cache_size_(0) {}
00321
00322 ~CacheLogAccumulatorData() {
00323 for(typename unordered_map<StateId, CacheState>::iterator it = cache_.begin();
00324 it != cache_.end();
00325 ++it)
00326 delete it->second.weights;
00327 }
00328
00329 bool CacheDisabled() const { return cache_gc_ && cache_limit_ == 0; }
00330
00331 vector<double> *GetWeights(StateId s) {
00332 typename unordered_map<StateId, CacheState>::iterator it = cache_.find(s);
00333 if (it != cache_.end()) {
00334 it->second.recent = true;
00335 return it->second.weights;
00336 } else {
00337 return 0;
00338 }
00339 }
00340
00341 void AddWeights(StateId s, vector<double> *weights) {
00342 if (cache_gc_ && cache_size_ >= cache_limit_)
00343 GC(false);
00344 cache_.insert(make_pair(s, CacheState(weights, true)));
00345 if (cache_gc_)
00346 cache_size_ += weights->capacity() * sizeof(double);
00347 }
00348
00349 int RefCount() const { return ref_count_.count(); }
00350 int IncrRefCount() { return ref_count_.Incr(); }
00351 int DecrRefCount() { return ref_count_.Decr(); }
00352
00353 private:
00354
00355 struct CacheState {
00356 vector<double>* weights;
00357 bool recent;
00358
00359 CacheState(vector<double> *w, bool r) : weights(w), recent(r) {}
00360 };
00361
00362
00363
00364
00365
00366 void GC(bool free_recent) {
00367 size_t cache_target = (2 * cache_limit_)/3 + 1;
00368 typename unordered_map<StateId, CacheState>::iterator it = cache_.begin();
00369 while (it != cache_.end() && cache_size_ > cache_target) {
00370 CacheState &cs = it->second;
00371 if (free_recent || !cs.recent) {
00372 cache_size_ -= cs.weights->capacity() * sizeof(double);
00373 delete cs.weights;
00374 cache_.erase(it++);
00375 } else {
00376 cs.recent = false;
00377 ++it;
00378 }
00379 }
00380 if (!free_recent && cache_size_ > cache_target)
00381 GC(true);
00382 }
00383
00384 unordered_map<StateId, CacheState> cache_;
00385 bool cache_gc_;
00386 size_t cache_limit_;
00387 size_t cache_size_;
00388 RefCounter ref_count_;
00389
00390 DISALLOW_COPY_AND_ASSIGN(CacheLogAccumulatorData);
00391 };
00392
00393
00394
00395
00396
00397
00398
00399 template <class A>
00400 class CacheLogAccumulator {
00401 public:
00402 typedef A Arc;
00403 typedef typename A::StateId StateId;
00404 typedef typename A::Weight Weight;
00405
00406 explicit CacheLogAccumulator(ssize_t arc_limit = 10, bool gc = false,
00407 size_t gc_limit = 10 * 1024 * 1024)
00408 : arc_limit_(arc_limit), fst_(0), data_(
00409 new CacheLogAccumulatorData<A>(gc, gc_limit)), s_(kNoStateId) {}
00410
00411 CacheLogAccumulator(const CacheLogAccumulator<A> &acc)
00412 : arc_limit_(acc.arc_limit_), fst_(acc.fst_ ? acc.fst_->Copy() : 0),
00413 data_(acc.data_), s_(kNoStateId) {
00414 data_->IncrRefCount();
00415 }
00416
00417 ~CacheLogAccumulator() {
00418 if (fst_)
00419 delete fst_;
00420 if (!data_->DecrRefCount())
00421 delete data_;
00422 }
00423
00424
00425 void Init(const Fst<A> &fst, bool copy = false) {
00426 if (copy)
00427 delete fst_;
00428 else
00429 CHECK(!fst_);
00430 fst_ = fst.Copy();
00431 }
00432
00433 void SetState(StateId s, int depth = 0) {
00434 if (s == s_)
00435 return;
00436 s_ = s;
00437
00438 if (data_->CacheDisabled()) {
00439 weights_ = 0;
00440 return;
00441 }
00442
00443 CHECK(fst_);
00444
00445 weights_ = data_->GetWeights(s);
00446 if ((weights_ == 0) && (fst_->NumArcs(s) >= arc_limit_)) {
00447 weights_ = new vector<double>;
00448 weights_->reserve(fst_->NumArcs(s) + 1);
00449 weights_->push_back(FloatLimits<double>::kPosInfinity);
00450 data_->AddWeights(s, weights_);
00451 }
00452 }
00453
00454 Weight Sum(Weight w, Weight v) {
00455 return LogPlus(w, v);
00456 }
00457
00458 template <class Iterator>
00459 Weight Sum(Weight w, Iterator *aiter, ssize_t begin,
00460 ssize_t end) {
00461 if (weights_ == 0) {
00462 Weight sum = w;
00463 aiter->Seek(begin);
00464 for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos)
00465 sum = LogPlus(sum, aiter->Value().weight);
00466 return sum;
00467 } else {
00468 if (weights_->size() <= end)
00469 for (aiter->Seek(weights_->size() - 1);
00470 weights_->size() <= end;
00471 aiter->Next())
00472 weights_->push_back(LogPlus(weights_->back(),
00473 aiter->Value().weight));
00474 return LogPlus(w, LogMinus((*weights_)[end], (*weights_)[begin]));
00475 }
00476 }
00477
00478 template <class Iterator>
00479 size_t LowerBound(double w, Iterator *aiter) {
00480 if (weights_ != 0) {
00481 return lower_bound(weights_->begin() + 1,
00482 weights_->end(),
00483 w,
00484 std::greater<double>())
00485 - weights_->begin() - 1;
00486 } else {
00487 size_t n = 0;
00488 double x = FloatLimits<double>::kPosInfinity;
00489 for(aiter->Reset(); !aiter->Done(); aiter->Next(), ++n) {
00490 x = LogPlus(x, aiter->Value().weight.Value());
00491 if (x < w) break;
00492 }
00493 return n;
00494 }
00495 }
00496
00497 private:
00498 double LogPosExp(double x) {
00499 return x == FloatLimits<double>::kPosInfinity ?
00500 0.0 : log(1.0F + exp(-x));
00501 }
00502
00503 double LogMinusExp(double x) {
00504 return x == FloatLimits<double>::kPosInfinity ?
00505 0.0 : log(1.0F - exp(-x));
00506 }
00507
00508 Weight LogPlus(Weight w, Weight v) {
00509 float f1 = w.Value();
00510 float f2 = v.Value();
00511 if (f1 > f2)
00512 return Weight(f2 - LogPosExp(f1 - f2));
00513 else
00514 return Weight(f1 - LogPosExp(f2 - f1));
00515 }
00516
00517 double LogPlus(double f1, Weight v) {
00518 float f2 = v.Value();
00519 if (f1 == FloatLimits<double>::kPosInfinity)
00520 return f2;
00521 else if (f1 > f2)
00522 return f2 - LogPosExp(f1 - f2);
00523 else
00524 return f1 - LogPosExp(f2 - f1);
00525 }
00526
00527 Weight LogMinus(double f1, double f2) {
00528 CHECK_LT(f1, f2);
00529 if (f2 == FloatLimits<double>::kPosInfinity)
00530 return f1;
00531 else
00532 return Weight(f1 - LogMinusExp(f2 - f1));
00533 }
00534
00535 ssize_t arc_limit_;
00536 vector<double> *weights_;
00537 const Fst<A>* fst_;
00538 CacheLogAccumulatorData<A> *data_;
00539 StateId s_;
00540
00541 void operator=(const CacheLogAccumulator<A> &);
00542 };
00543
00544
00545
00546 template <class Accumulator, class T>
00547 class ReplaceAccumulatorData {
00548 public:
00549 typedef typename Accumulator::Arc Arc;
00550 typedef typename Arc::StateId StateId;
00551 typedef typename Arc::Label Label;
00552 typedef T StateTable;
00553 typedef typename T::StateTuple StateTuple;
00554
00555 ReplaceAccumulatorData() : state_table_(0) {}
00556
00557 ReplaceAccumulatorData(const vector<Accumulator*> &accumulators)
00558 : state_table_(0), accumulators_(accumulators) {}
00559
00560 ~ReplaceAccumulatorData() {
00561 for (size_t i = 0; i < fst_array_.size(); ++i)
00562 delete fst_array_[i];
00563 for (size_t i = 0; i < accumulators_.size(); ++i)
00564 delete accumulators_[i];
00565 }
00566
00567 void Init(const vector<pair<Label, const Fst<Arc>*> > &fst_tuples,
00568 const StateTable *state_table) {
00569 state_table_ = state_table;
00570 accumulators_.resize(fst_tuples.size());
00571 for (size_t i = 0; i < accumulators_.size(); ++i) {
00572 if (!accumulators_[i])
00573 accumulators_[i] = new Accumulator;
00574 accumulators_[i]->Init(*(fst_tuples[i].second));
00575 fst_array_.push_back(fst_tuples[i].second->Copy());
00576 }
00577 }
00578
00579 const StateTuple &GetTuple(StateId s) const {
00580 return state_table_->Tuple(s);
00581 }
00582
00583 Accumulator *GetAccumulator(size_t i) { return accumulators_[i]; }
00584
00585 const Fst<Arc> *GetFst(size_t i) const { return fst_array_[i]; }
00586
00587 int RefCount() const { return ref_count_.count(); }
00588 int IncrRefCount() { return ref_count_.Incr(); }
00589 int DecrRefCount() { return ref_count_.Decr(); }
00590
00591 private:
00592 const T * state_table_;
00593 vector<Accumulator*> accumulators_;
00594 vector<const Fst<Arc>*> fst_array_;
00595 RefCounter ref_count_;
00596
00597 DISALLOW_COPY_AND_ASSIGN(ReplaceAccumulatorData);
00598 };
00599
00600
00601
00602
00603
00604 template <class Accumulator,
00605 class T = DefaultReplaceStateTable<typename Accumulator::Arc> >
00606 class ReplaceAccumulator {
00607 public:
00608 typedef typename Accumulator::Arc Arc;
00609 typedef typename Arc::StateId StateId;
00610 typedef typename Arc::Label Label;
00611 typedef typename Arc::Weight Weight;
00612 typedef T StateTable;
00613 typedef typename T::StateTuple StateTuple;
00614
00615 ReplaceAccumulator()
00616 : init_(false), data_(new ReplaceAccumulatorData<Accumulator, T>()) {}
00617
00618 ReplaceAccumulator(const vector<Accumulator*> &accumulators)
00619 : init_(false),
00620 data_(new ReplaceAccumulatorData<Accumulator, T>(accumulators)) {}
00621
00622 ReplaceAccumulator(const ReplaceAccumulator<Accumulator, T> &acc)
00623 : init_(acc.init_), data_(acc.data_) {
00624 if (!init_)
00625 LOG(FATAL) << "ReplaceAccumulator: can't copy unintialized accumulator";
00626 data_->IncrRefCount();
00627 }
00628
00629 ~ReplaceAccumulator() {
00630 if (!data_->DecrRefCount())
00631 delete data_;
00632 }
00633
00634
00635
00636 void Init(const vector<pair<Label, const Fst<Arc>*> > &fst_tuples,
00637 const StateTable *state_table) {
00638 init_ = true;
00639 data_->Init(fst_tuples, state_table);
00640 }
00641
00642 void SetState(StateId s) {
00643 CHECK(init_);
00644 StateTuple tuple = data_->GetTuple(s);
00645 fst_id_ = tuple.fst_id - 1;
00646 data_->GetAccumulator(fst_id_)->SetState(tuple.fst_state);
00647 if ((tuple.prefix_id != 0) &&
00648 (data_->GetFst(fst_id_)->Final(tuple.fst_state) != Weight::Zero())) {
00649 offset_ = 1;
00650 offset_weight_ = data_->GetFst(fst_id_)->Final(tuple.fst_state);
00651 } else {
00652 offset_ = 0;
00653 offset_weight_ = Weight::Zero();
00654 }
00655 }
00656
00657 Weight Sum(Weight w, Weight v) {
00658 return data_->GetAccumulator(fst_id_)->Sum(w, v);
00659 }
00660
00661 template <class ArcIterator>
00662 Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin,
00663 ssize_t end) {
00664 Weight sum = begin == end ? Weight::Zero()
00665 : data_->GetAccumulator(fst_id_)->Sum(
00666 w, aiter, begin ? begin - offset_ : 0, end - offset_);
00667 if (begin == 0 && end != 0 && offset_ > 0)
00668 sum = Sum(offset_weight_, sum);
00669 return sum;
00670 }
00671
00672
00673 private:
00674 bool init_;
00675 ReplaceAccumulatorData<Accumulator, T> *data_;
00676 Label fst_id_;
00677 size_t offset_;
00678 Weight offset_weight_;
00679
00680 void operator=(const ReplaceAccumulator<Accumulator, T> &);
00681 };
00682
00683 }
00684
00685 #endif /// FST_LIB_ACCUMULATOR_H__
00686