00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 #ifndef FST_LIB_REPLACE_H__
00023 #define FST_LIB_REPLACE_H__
00024
00025 #include <tr1/unordered_map>
00026 using std::tr1::unordered_map;
00027 using std::tr1::unordered_multimap;
00028 #include <string>
00029 #include <utility>
00030 using std::pair; using std::make_pair;
00031 #include <vector>
00032 using std::vector;
00033
00034 #include <fst/fst.h>
00035 #include <fst/cache.h>
00036 #include <fst/expanded-fst.h>
00037 #include <fst/matcher.h>
00038 #include <fst/replace-util.h>
00039 #include <fst/state-table.h>
00040 #include <fst/test-properties.h>
00041
00042 namespace fst {
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067
00068
00069
00070
00071
00072
00073
00074
00075 template <class S, class P>
00076 struct ReplaceStateTuple {
00077 typedef S StateId;
00078 typedef P PrefixId;
00079
00080 ReplaceStateTuple()
00081 : prefix_id(-1), fst_id(kNoStateId), fst_state(kNoStateId) {}
00082
00083 ReplaceStateTuple(PrefixId p, StateId f, StateId s)
00084 : prefix_id(p), fst_id(f), fst_state(s) {}
00085
00086 PrefixId prefix_id;
00087 StateId fst_id;
00088 StateId fst_state;
00089
00090 };
00091
00092
00093
00094 template <class S, class P>
00095 inline bool operator==(const ReplaceStateTuple<S, P>& x,
00096 const ReplaceStateTuple<S, P>& y) {
00097 return x.prefix_id == y.prefix_id &&
00098 x.fst_id == y.fst_id &&
00099 x.fst_state == y.fst_state;
00100 }
00101
00102
00103
00104
00105 template <class S, class P>
00106 class ReplaceRootSelector {
00107 public:
00108 bool operator()(const ReplaceStateTuple<S, P> &tuple) const {
00109 return tuple.prefix_id == 0;
00110 }
00111 };
00112
00113
00114
00115
00116 template <class S, class P>
00117 class ReplaceFingerprint {
00118 public:
00119 ReplaceFingerprint(const vector<uint64> *size_array)
00120 : cumulative_size_array_(size_array) {}
00121
00122 uint64 operator()(const ReplaceStateTuple<S, P> &tuple) const {
00123 return tuple.prefix_id * (cumulative_size_array_->back()) +
00124 cumulative_size_array_->at(tuple.fst_id - 1) +
00125 tuple.fst_state;
00126 }
00127
00128 private:
00129 const vector<uint64> *cumulative_size_array_;
00130 };
00131
00132
00133
00134
00135 template <class S, class P>
00136 class ReplaceFstStateFingerprint {
00137 public:
00138 uint64 operator()(const ReplaceStateTuple<S, P>& tuple) const {
00139 return tuple.fst_state;
00140 }
00141 };
00142
00143
00144
00145
00146 template <typename S, typename P>
00147 class ReplaceHash {
00148 public:
00149 size_t operator()(const ReplaceStateTuple<S, P>& t) const {
00150 return t.prefix_id + t.fst_id * kPrime0 + t.fst_state * kPrime1;
00151 }
00152 private:
00153 static const size_t kPrime0;
00154 static const size_t kPrime1;
00155 };
00156
00157 template <typename S, typename P>
00158 const size_t ReplaceHash<S, P>::kPrime0 = 7853;
00159
00160 template <typename S, typename P>
00161 const size_t ReplaceHash<S, P>::kPrime1 = 7867;
00162
00163
00164
00165
00166
00167
00168 template <class A, class P = ssize_t>
00169 class VectorHashReplaceStateTable {
00170 public:
00171 typedef A Arc;
00172 typedef typename A::StateId StateId;
00173 typedef typename A::Label Label;
00174 typedef P PrefixId;
00175 typedef ReplaceStateTuple<StateId, P> StateTuple;
00176 typedef VectorHashStateTable<ReplaceStateTuple<StateId, P>,
00177 ReplaceRootSelector<StateId, P>,
00178 ReplaceFstStateFingerprint<StateId, P>,
00179 ReplaceFingerprint<StateId, P> > StateTable;
00180
00181 VectorHashReplaceStateTable(
00182 const vector<pair<Label, const Fst<A>*> > &fst_tuples,
00183 Label root) : root_size_(0) {
00184 cumulative_size_array_.push_back(0);
00185 for (size_t i = 0; i < fst_tuples.size(); ++i) {
00186 if (fst_tuples[i].first == root) {
00187 root_size_ = CountStates(*(fst_tuples[i].second));
00188 cumulative_size_array_.push_back(cumulative_size_array_.back());
00189 } else {
00190 cumulative_size_array_.push_back(cumulative_size_array_.back() +
00191 CountStates(*(fst_tuples[i].second)));
00192 }
00193 }
00194 state_table_ = new StateTable(
00195 new ReplaceRootSelector<StateId, P>,
00196 new ReplaceFstStateFingerprint<StateId, P>,
00197 new ReplaceFingerprint<StateId, P>(&cumulative_size_array_),
00198 root_size_,
00199 root_size_ + cumulative_size_array_.back());
00200 }
00201
00202 VectorHashReplaceStateTable(const VectorHashReplaceStateTable<A, P> &table)
00203 : root_size_(table.root_size_),
00204 cumulative_size_array_(table.cumulative_size_array_) {
00205 state_table_ = new StateTable(
00206 new ReplaceRootSelector<StateId, P>,
00207 new ReplaceFstStateFingerprint<StateId, P>,
00208 new ReplaceFingerprint<StateId, P>(&cumulative_size_array_),
00209 root_size_,
00210 root_size_ + cumulative_size_array_.back());
00211 }
00212
00213 ~VectorHashReplaceStateTable() {
00214 delete state_table_;
00215 }
00216
00217 StateId FindState(const StateTuple &tuple) {
00218 return state_table_->FindState(tuple);
00219 }
00220
00221 const StateTuple &Tuple(StateId id) const {
00222 return state_table_->Tuple(id);
00223 }
00224
00225 private:
00226 StateId root_size_;
00227 vector<uint64> cumulative_size_array_;
00228 StateTable *state_table_;
00229 };
00230
00231
00232
00233
00234 template <class A, class P = ssize_t>
00235 class DefaultReplaceStateTable : public CompactHashStateTable<
00236 ReplaceStateTuple<typename A::StateId, P>,
00237 ReplaceHash<typename A::StateId, P> > {
00238 public:
00239 typedef A Arc;
00240 typedef typename A::StateId StateId;
00241 typedef typename A::Label Label;
00242 typedef P PrefixId;
00243 typedef ReplaceStateTuple<StateId, P> StateTuple;
00244 typedef CompactHashStateTable<StateTuple,
00245 ReplaceHash<StateId, PrefixId> > StateTable;
00246
00247 using StateTable::FindState;
00248 using StateTable::Tuple;
00249
00250 DefaultReplaceStateTable(
00251 const vector<pair<Label, const Fst<A>*> > &fst_tuples,
00252 Label root) {}
00253
00254 DefaultReplaceStateTable(const DefaultReplaceStateTable<A, P> &table)
00255 : StateTable() {}
00256 };
00257
00258
00259
00260
00261
00262
00263
00264
00265
00266 template <class A, class T = DefaultReplaceStateTable<A> >
00267 struct ReplaceFstOptions : CacheOptions {
00268 int64 root;
00269 bool epsilon_on_replace;
00270 bool take_ownership;
00271 T* state_table;
00272
00273 ReplaceFstOptions(const CacheOptions &opts, int64 r)
00274 : CacheOptions(opts),
00275 root(r),
00276 epsilon_on_replace(false),
00277 take_ownership(false),
00278 state_table(0) {}
00279 explicit ReplaceFstOptions(int64 r)
00280 : root(r),
00281 epsilon_on_replace(false),
00282 take_ownership(false),
00283 state_table(0) {}
00284 ReplaceFstOptions(int64 r, bool epsilon_replace_arc)
00285 : root(r),
00286 epsilon_on_replace(epsilon_replace_arc),
00287 take_ownership(false),
00288 state_table(0) {}
00289 ReplaceFstOptions()
00290 : root(kNoLabel),
00291 epsilon_on_replace(false),
00292 take_ownership(false),
00293 state_table(0) {}
00294 };
00295
00296
00297
00298
00299
00300
00301
00302
00303
00304 template <class A, class T>
00305 class ReplaceFstImpl : public CacheImpl<A> {
00306 public:
00307 using FstImpl<A>::SetType;
00308 using FstImpl<A>::SetProperties;
00309 using FstImpl<A>::Properties;
00310 using FstImpl<A>::WriteHeader;
00311 using FstImpl<A>::SetInputSymbols;
00312 using FstImpl<A>::SetOutputSymbols;
00313 using FstImpl<A>::InputSymbols;
00314 using FstImpl<A>::OutputSymbols;
00315
00316 using CacheImpl<A>::AddArc;
00317 using CacheImpl<A>::HasArcs;
00318 using CacheImpl<A>::HasFinal;
00319 using CacheImpl<A>::HasStart;
00320 using CacheImpl<A>::SetArcs;
00321 using CacheImpl<A>::SetFinal;
00322 using CacheImpl<A>::SetStart;
00323
00324 typedef typename A::Label Label;
00325 typedef typename A::Weight Weight;
00326 typedef typename A::StateId StateId;
00327 typedef CacheState<A> State;
00328 typedef A Arc;
00329 typedef unordered_map<Label, Label> NonTerminalHash;
00330
00331 typedef T StateTable;
00332 typedef typename T::PrefixId PrefixId;
00333 typedef ReplaceStateTuple<StateId, PrefixId> StateTuple;
00334
00335
00336
00337 ReplaceFstImpl(const vector< pair<Label, const Fst<A>* > >& fst_tuples,
00338 const ReplaceFstOptions<A, T> &opts)
00339 : CacheImpl<A>(opts),
00340 epsilon_on_replace_(opts.epsilon_on_replace),
00341 state_table_(opts.state_table ? opts.state_table :
00342 new StateTable(fst_tuples, opts.root)) {
00343
00344 SetType("replace");
00345
00346 if (fst_tuples.size() > 0) {
00347 SetInputSymbols(fst_tuples[0].second->InputSymbols());
00348 SetOutputSymbols(fst_tuples[0].second->OutputSymbols());
00349 }
00350
00351 bool all_negative = true;
00352 bool dense_range = true;
00353
00354 for (size_t i = 0; i < fst_tuples.size(); ++i) {
00355 Label nonterminal = fst_tuples[i].first;
00356 if (nonterminal >= 0)
00357 all_negative = false;
00358 if (nonterminal > fst_tuples.size() || nonterminal <= 0)
00359 dense_range = false;
00360 }
00361
00362 vector<uint64> inprops;
00363 bool all_ilabel_sorted = true;
00364 bool all_olabel_sorted = true;
00365 bool all_non_empty = true;
00366 fst_array_.push_back(0);
00367 for (size_t i = 0; i < fst_tuples.size(); ++i) {
00368 Label label = fst_tuples[i].first;
00369 const Fst<A> *fst = fst_tuples[i].second;
00370 nonterminal_hash_[label] = fst_array_.size();
00371 fst_array_.push_back(opts.take_ownership ? fst : fst->Copy());
00372 if (fst->Start() == kNoStateId)
00373 all_non_empty = false;
00374 if(!fst->Properties(kILabelSorted, false))
00375 all_ilabel_sorted = false;
00376 if(!fst->Properties(kOLabelSorted, false))
00377 all_olabel_sorted = false;
00378 inprops.push_back(fst->Properties(kCopyProperties, false));
00379 if (i) {
00380 if (!CompatSymbols(InputSymbols(), fst->InputSymbols())) {
00381 LOG(FATAL) << "ReplaceFstImpl: input symbols of Fst " << i
00382 << " does not match input symbols of base Fst (0'th fst)";
00383 }
00384 if (!CompatSymbols(OutputSymbols(), fst->OutputSymbols())) {
00385 LOG(FATAL) << "ReplaceFstImpl: output symbols of Fst " << i
00386 << " does not match output symbols of base Fst "
00387 << "(0'th fst)";
00388 }
00389 }
00390 }
00391 Label nonterminal = nonterminal_hash_[opts.root];
00392 if ((nonterminal == 0) && (fst_array_.size() > 1)) {
00393 LOG(FATAL) << "ReplaceFstImpl: no Fst corresponding to root label '"
00394 << opts.root << "' in the input tuple vector";
00395 }
00396 root_ = (nonterminal > 0) ? nonterminal : 1;
00397
00398 SetProperties(ReplaceProperties(inprops, root_ - 1, epsilon_on_replace_,
00399 all_non_empty));
00400
00401
00402
00403
00404
00405
00406 if (all_ilabel_sorted &&
00407 (!epsilon_on_replace_ || all_negative || dense_range))
00408 SetProperties(kILabelSorted, kILabelSorted);
00409
00410
00411
00412
00413
00414 if (all_olabel_sorted && (all_negative || dense_range))
00415 SetProperties(kOLabelSorted, kOLabelSorted);
00416
00417
00418 if (Properties(kILabelSorted | kOLabelSorted) && all_non_empty)
00419 always_cache_ = false;
00420 else
00421 always_cache_ = true;
00422 VLOG(2) << "ReplaceFstImpl::ReplaceFstImpl: always_cache = "
00423 << (always_cache_ ? "true" : "false");
00424 }
00425
00426 ReplaceFstImpl(const ReplaceFstImpl& impl)
00427 : CacheImpl<A>(impl),
00428 epsilon_on_replace_(impl.epsilon_on_replace_),
00429 always_cache_(impl.always_cache_),
00430 state_table_(new StateTable(*(impl.state_table_))),
00431 nonterminal_hash_(impl.nonterminal_hash_),
00432 root_(impl.root_) {
00433 SetType("replace");
00434 SetProperties(impl.Properties(), kCopyProperties);
00435 SetInputSymbols(impl.InputSymbols());
00436 SetOutputSymbols(impl.OutputSymbols());
00437 fst_array_.reserve(impl.fst_array_.size());
00438 fst_array_.push_back(0);
00439 for (size_t i = 1; i < impl.fst_array_.size(); ++i) {
00440 fst_array_.push_back(impl.fst_array_[i]->Copy(true));
00441 }
00442 }
00443
00444 ~ReplaceFstImpl() {
00445 VLOG(2) << "~ReplaceFstImpl: gc = "
00446 << (CacheImpl<A>::GetCacheGc() ? "true" : "false")
00447 << ", gc_size = " << CacheImpl<A>::GetCacheSize()
00448 << ", gc_limit = " << CacheImpl<A>::GetCacheLimit();
00449
00450 delete state_table_;
00451 for (size_t i = 1; i < fst_array_.size(); ++i) {
00452 delete fst_array_[i];
00453 }
00454 }
00455
00456
00457
00458
00459 bool CyclicDependencies() const {
00460 ReplaceUtil<A> replace_util(fst_array_, nonterminal_hash_, root_);
00461 return replace_util.CyclicDependencies();
00462 }
00463
00464
00465 StateId Start() {
00466 if (!HasStart()) {
00467 if (fst_array_.size() == 1) {
00468 SetStart(kNoStateId);
00469 return kNoStateId;
00470 } else {
00471 const Fst<A>* fst = fst_array_[root_];
00472 StateId fst_start = fst->Start();
00473 if (fst_start == kNoStateId)
00474 return kNoStateId;
00475
00476 PrefixId prefix = GetPrefixId(StackPrefix());
00477 StateId start = state_table_->FindState(
00478 StateTuple(prefix, root_, fst_start));
00479 SetStart(start);
00480 return start;
00481 }
00482 } else {
00483 return CacheImpl<A>::Start();
00484 }
00485 }
00486
00487
00488 Weight Final(StateId s) {
00489 if (!HasFinal(s)) {
00490 const StateTuple& tuple = state_table_->Tuple(s);
00491 const StackPrefix& stack = stackprefix_array_[tuple.prefix_id];
00492 const Fst<A>* fst = fst_array_[tuple.fst_id];
00493 StateId fst_state = tuple.fst_state;
00494
00495 if (fst->Final(fst_state) != Weight::Zero() && stack.Depth() == 0)
00496 SetFinal(s, fst->Final(fst_state));
00497 else
00498 SetFinal(s, Weight::Zero());
00499 }
00500 return CacheImpl<A>::Final(s);
00501 }
00502
00503 size_t NumArcs(StateId s) {
00504 if (HasArcs(s)) {
00505 return CacheImpl<A>::NumArcs(s);
00506 } else if (always_cache_) {
00507 Expand(s);
00508 return CacheImpl<A>::NumArcs(s);
00509 } else {
00510 StateTuple tuple = state_table_->Tuple(s);
00511 if (tuple.fst_state == kNoStateId)
00512 return 0;
00513
00514 const Fst<A>* fst = fst_array_[tuple.fst_id];
00515 size_t num_arcs = fst->NumArcs(tuple.fst_state);
00516 if (ComputeFinalArc(tuple, 0))
00517 num_arcs++;
00518
00519 return num_arcs;
00520 }
00521 }
00522
00523
00524 bool IsNonTerminal(Label l) const {
00525
00526
00527
00528
00529
00530
00531 typename NonTerminalHash::const_iterator it =
00532 nonterminal_hash_.find(l);
00533 return it != nonterminal_hash_.end();
00534 }
00535
00536 size_t NumInputEpsilons(StateId s) {
00537 if (HasArcs(s)) {
00538
00539 return CacheImpl<A>::NumInputEpsilons(s);
00540 } else if (always_cache_ || !Properties(kILabelSorted)) {
00541
00542
00543
00544 Expand(s);
00545 return CacheImpl<A>::NumInputEpsilons(s);
00546 } else {
00547
00548 StateTuple tuple = state_table_->Tuple(s);
00549 if (tuple.fst_state == kNoStateId)
00550 return 0;
00551 const Fst<A>* fst = fst_array_[tuple.fst_id];
00552 size_t num = 0;
00553 if (!epsilon_on_replace_) {
00554
00555
00556 fst->NumInputEpsilons(tuple.fst_state);
00557 } else {
00558
00559
00560 ArcIterator<Fst<A> > aiter(*fst, tuple.fst_state);
00561 for (; !aiter.Done() &&
00562 ((aiter.Value().ilabel == 0) ||
00563 IsNonTerminal(aiter.Value().olabel));
00564 aiter.Next())
00565 ++num;
00566 }
00567 if (ComputeFinalArc(tuple, 0))
00568 num++;
00569 return num;
00570 }
00571 }
00572
00573 size_t NumOutputEpsilons(StateId s) {
00574 if (HasArcs(s)) {
00575
00576 return CacheImpl<A>::NumOutputEpsilons(s);
00577 } else if(always_cache_ || !Properties(kOLabelSorted)) {
00578
00579
00580
00581 Expand(s);
00582 return CacheImpl<A>::NumOutputEpsilons(s);
00583 } else {
00584
00585 StateTuple tuple = state_table_->Tuple(s);
00586 if (tuple.fst_state == kNoStateId)
00587 return 0;
00588 const Fst<A>* fst = fst_array_[tuple.fst_id];
00589 size_t num = 0;
00590 ArcIterator<Fst<A> > aiter(*fst, tuple.fst_state);
00591 for (; !aiter.Done() &&
00592 ((aiter.Value().olabel == 0) ||
00593 IsNonTerminal(aiter.Value().olabel));
00594 aiter.Next())
00595 ++num;
00596 if (ComputeFinalArc(tuple, 0))
00597 num++;
00598 return num;
00599 }
00600 }
00601
00602
00603
00604 void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
00605 if (!HasArcs(s))
00606 Expand(s);
00607 CacheImpl<A>::InitArcIterator(s, data);
00608
00609
00610
00611 }
00612
00613
00614
00615 void Expand(StateId s) {
00616 StateTuple tuple = state_table_->Tuple(s);
00617
00618
00619 if (tuple.fst_state == kNoStateId) {
00620 SetArcs(s);
00621 return;
00622 }
00623
00624 ArcIterator< Fst<A> > aiter(
00625 *(fst_array_[tuple.fst_id]), tuple.fst_state);
00626 Arc arc;
00627
00628
00629 if (ComputeFinalArc(tuple, &arc))
00630 AddArc(s, arc);
00631
00632
00633 for (;!aiter.Done(); aiter.Next()) {
00634 if (ComputeArc(tuple, aiter.Value(), &arc))
00635 AddArc(s, arc);
00636 }
00637
00638 SetArcs(s);
00639 }
00640
00641 void Expand(StateId s, const StateTuple &tuple,
00642 const ArcIteratorData<A> &data) {
00643
00644 if (tuple.fst_state == kNoStateId) {
00645 SetArcs(s);
00646 return;
00647 }
00648
00649 ArcIterator< Fst<A> > aiter(data);
00650 Arc arc;
00651
00652
00653 if (ComputeFinalArc(tuple, &arc))
00654 AddArc(s, arc);
00655
00656
00657 for (; !aiter.Done(); aiter.Next()) {
00658 if (ComputeArc(tuple, aiter.Value(), &arc))
00659 AddArc(s, arc);
00660 }
00661
00662 SetArcs(s);
00663 }
00664
00665
00666
00667 bool ComputeFinalArc(const StateTuple &tuple, A* arcp,
00668 uint32 flags = kArcValueFlags) {
00669 const Fst<A>* fst = fst_array_[tuple.fst_id];
00670 StateId fst_state = tuple.fst_state;
00671 if (fst_state == kNoStateId)
00672 return false;
00673
00674
00675 const StackPrefix& stack = stackprefix_array_[tuple.prefix_id];
00676 if (fst->Final(fst_state) != Weight::Zero() && stack.Depth()) {
00677 if (arcp) {
00678 arcp->ilabel = 0;
00679 arcp->olabel = 0;
00680 if (flags & kArcNextStateValue) {
00681 PrefixId prefix_id = PopPrefix(stack);
00682 const PrefixTuple& top = stack.Top();
00683 arcp->nextstate = state_table_->FindState(
00684 StateTuple(prefix_id, top.fst_id, top.nextstate));
00685 }
00686 if (flags & kArcWeightValue)
00687 arcp->weight = fst->Final(fst_state);
00688 }
00689 return true;
00690 } else {
00691 return false;
00692 }
00693 }
00694
00695
00696
00697
00698 bool ComputeArc(const StateTuple &tuple, const A &arc, A* arcp,
00699 uint32 flags = kArcValueFlags) {
00700 if (!epsilon_on_replace_ &&
00701 (flags == (flags & (kArcILabelValue | kArcWeightValue)))) {
00702 *arcp = arc;
00703 return true;
00704 }
00705
00706 if (arc.olabel == 0) {
00707 StateId nextstate = flags & kArcNextStateValue
00708 ? state_table_->FindState(
00709 StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
00710 : kNoStateId;
00711 *arcp = A(arc.ilabel, arc.olabel, arc.weight, nextstate);
00712 } else {
00713
00714 typename NonTerminalHash::const_iterator it =
00715 nonterminal_hash_.find(arc.olabel);
00716 if (it != nonterminal_hash_.end()) {
00717 Label nonterminal = it->second;
00718 const Fst<A>* nt_fst = fst_array_[nonterminal];
00719 PrefixId nt_prefix = PushPrefix(stackprefix_array_[tuple.prefix_id],
00720 tuple.fst_id, arc.nextstate);
00721
00722
00723
00724 StateId nt_start = nt_fst->Start();
00725 if (nt_start != kNoStateId) {
00726 StateId nt_nextstate = flags & kArcNextStateValue
00727 ? state_table_->FindState(
00728 StateTuple(nt_prefix, nonterminal, nt_start))
00729 : kNoStateId;
00730 Label ilabel = (epsilon_on_replace_) ? 0 : arc.ilabel;
00731 *arcp = A(ilabel, 0, arc.weight, nt_nextstate);
00732 } else {
00733 return false;
00734 }
00735 } else {
00736 StateId nextstate = flags & kArcNextStateValue
00737 ? state_table_->FindState(
00738 StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
00739 : kNoStateId;
00740 *arcp = A(arc.ilabel, arc.olabel, arc.weight, nextstate);
00741 }
00742 }
00743 return true;
00744 }
00745
00746
00747 uint32 ArcIteratorFlags() const {
00748 uint32 flags = kArcValueFlags;
00749 if (!always_cache_)
00750 flags |= kArcNoCache;
00751 return flags;
00752 }
00753
00754 T* GetStateTable() const {
00755 return state_table_;
00756 }
00757
00758 const Fst<A>* GetFst(Label fst_id) const {
00759 return fst_array_[fst_id];
00760 }
00761
00762 bool EpsilonOnReplace() const { return epsilon_on_replace_; }
00763
00764
00765 private:
00766 static const size_t kPrime0;
00767
00768
00769
00770 struct PrefixTuple {
00771 PrefixTuple(Label f, StateId s) : fst_id(f), nextstate(s) {}
00772
00773 Label fst_id;
00774 StateId nextstate;
00775 };
00776
00777
00778
00779 class StackPrefix {
00780 public:
00781 StackPrefix() {}
00782
00783
00784 StackPrefix(const StackPrefix& x) :
00785 prefix_(x.prefix_) {
00786 }
00787
00788 void Push(StateId fst_id, StateId nextstate) {
00789 prefix_.push_back(PrefixTuple(fst_id, nextstate));
00790 }
00791
00792 void Pop() {
00793 prefix_.pop_back();
00794 }
00795
00796 const PrefixTuple& Top() const {
00797 return prefix_[prefix_.size()-1];
00798 }
00799
00800 size_t Depth() const {
00801 return prefix_.size();
00802 }
00803
00804 public:
00805 vector<PrefixTuple> prefix_;
00806 };
00807
00808
00809
00810
00811 class StackPrefixEqual {
00812 public:
00813 bool operator()(const StackPrefix& x, const StackPrefix& y) const {
00814 if (x.prefix_.size() != y.prefix_.size()) return false;
00815 for (size_t i = 0; i < x.prefix_.size(); ++i) {
00816 if (x.prefix_[i].fst_id != y.prefix_[i].fst_id ||
00817 x.prefix_[i].nextstate != y.prefix_[i].nextstate) return false;
00818 }
00819 return true;
00820 }
00821 };
00822
00823
00824
00825
00826 class StackPrefixKey {
00827 public:
00828 size_t operator()(const StackPrefix& x) const {
00829 size_t sum = 0;
00830 for (size_t i = 0; i < x.prefix_.size(); ++i) {
00831 sum += x.prefix_[i].fst_id + x.prefix_[i].nextstate*kPrime0;
00832 }
00833 return sum;
00834 }
00835 };
00836
00837 typedef unordered_map<StackPrefix, PrefixId, StackPrefixKey, StackPrefixEqual>
00838 StackPrefixHash;
00839
00840
00841 private:
00842
00843 PrefixId GetPrefixId(const StackPrefix& prefix) {
00844 typename StackPrefixHash::iterator it = prefix_hash_.find(prefix);
00845 if (it == prefix_hash_.end()) {
00846 PrefixId prefix_id = stackprefix_array_.size();
00847 stackprefix_array_.push_back(prefix);
00848 prefix_hash_[prefix] = prefix_id;
00849 return prefix_id;
00850 } else {
00851 return it->second;
00852 }
00853 }
00854
00855
00856 PrefixId PopPrefix(StackPrefix prefix) {
00857 prefix.Pop();
00858 return GetPrefixId(prefix);
00859 }
00860
00861
00862 PrefixId PushPrefix(StackPrefix prefix, Label fst_id, StateId nextstate) {
00863 prefix.Push(fst_id, nextstate);
00864 return GetPrefixId(prefix);
00865 }
00866
00867
00868
00869 private:
00870
00871 bool epsilon_on_replace_;
00872 bool always_cache_;
00873
00874
00875 StateTable *state_table_;
00876
00877
00878
00879 StackPrefixHash prefix_hash_;
00880 vector<StackPrefix> stackprefix_array_;
00881
00882 NonTerminalHash nonterminal_hash_;
00883 vector<const Fst<A>*> fst_array_;
00884 Label root_;
00885
00886 void operator=(const ReplaceFstImpl<A, T> &);
00887 };
00888
00889
00890 template <class A, class T>
00891 const size_t ReplaceFstImpl<A, T>::kPrime0 = 7853;
00892
00893
00894
00895
00896
00897
00898
00899
00900
00901
00902
00903
00904
00905
00906
00907
00908
00909
00910
00911
00912
00913
00914
00915
00916
00917
00918
00919
00920
00921
00922
00923
00924
00925
00926
00927
00928
00929
00930
00931 template <class A, class T = DefaultReplaceStateTable<A> >
00932 class ReplaceFst : public ImplToFst< ReplaceFstImpl<A, T> > {
00933 public:
00934 friend class ArcIterator< ReplaceFst<A, T> >;
00935 friend class StateIterator< ReplaceFst<A, T> >;
00936
00937 typedef A Arc;
00938 typedef typename A::Label Label;
00939 typedef typename A::Weight Weight;
00940 typedef typename A::StateId StateId;
00941 typedef CacheState<A> State;
00942 typedef ReplaceFstImpl<A, T> Impl;
00943
00944 using ImplToFst<Impl>::Properties;
00945
00946 ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,
00947 Label root)
00948 : ImplToFst<Impl>(new Impl(fst_array, ReplaceFstOptions<A, T>(root))) {}
00949
00950 ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,
00951 const ReplaceFstOptions<A, T> &opts)
00952 : ImplToFst<Impl>(new Impl(fst_array, opts)) {}
00953
00954
00955 ReplaceFst(const ReplaceFst<A, T>& fst, bool safe = false)
00956 : ImplToFst<Impl>(fst, safe) {}
00957
00958
00959 virtual ReplaceFst<A, T> *Copy(bool safe = false) const {
00960 return new ReplaceFst<A, T>(*this, safe);
00961 }
00962
00963 virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
00964
00965 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
00966 GetImpl()->InitArcIterator(s, data);
00967 }
00968
00969 virtual MatcherBase<A> *InitMatcher(MatchType match_type) const {
00970 if ((GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
00971 ((match_type == MATCH_INPUT && Properties(kILabelSorted, false)) ||
00972 (match_type == MATCH_OUTPUT && Properties(kOLabelSorted, false))))
00973 return new SortedMatcher< ReplaceFst<A> >(*this, match_type);
00974 else
00975 return 0;
00976 }
00977
00978 bool CyclicDependencies() const {
00979 return GetImpl()->CyclicDependencies();
00980 }
00981
00982 private:
00983
00984 Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
00985
00986 void operator=(const ReplaceFst<A> &fst);
00987 };
00988
00989
00990
00991 template<class A, class T>
00992 class StateIterator< ReplaceFst<A, T> >
00993 : public CacheStateIterator< ReplaceFst<A, T> > {
00994 public:
00995 explicit StateIterator(const ReplaceFst<A, T> &fst)
00996 : CacheStateIterator< ReplaceFst<A, T> >(fst, fst.GetImpl()) {}
00997
00998 private:
00999 DISALLOW_COPY_AND_ASSIGN(StateIterator);
01000 };
01001
01002
01003
01004
01005
01006
01007
01008
01009
01010
01011
01012
01013
01014
01015
01016
01017
01018
01019
01020
01021
01022
01023 template <class A, class T>
01024 class ArcIterator< ReplaceFst<A, T> > {
01025 public:
01026 typedef A Arc;
01027 typedef typename A::StateId StateId;
01028
01029 ArcIterator(const ReplaceFst<A, T> &fst, StateId s)
01030 : fst_(fst), state_(s), pos_(0), offset_(0), flags_(0), arcs_(0),
01031 data_flags_(0), final_flags_(0) {
01032 cache_data_.ref_count = 0;
01033 local_data_.ref_count = 0;
01034
01035
01036 if(!(fst_.GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
01037 !(fst_.GetImpl()->HasArcs(state_)))
01038 fst_.GetImpl()->Expand(state_);
01039
01040
01041 if (fst_.GetImpl()->HasArcs(state_)) {
01042 (fst_.GetImpl())->template CacheImpl<A>::InitArcIterator(state_,
01043 &cache_data_);
01044 num_arcs_ = cache_data_.narcs;
01045 arcs_ = cache_data_.arcs;
01046 data_flags_ = kArcValueFlags;
01047 } else {
01048 tuple_ = fst_.GetImpl()->GetStateTable()->Tuple(state_);
01049 if (tuple_.fst_state == kNoStateId) {
01050 num_arcs_ = 0;
01051 } else {
01052
01053
01054
01055
01056 const Fst<A>* fst = fst_.GetImpl()->GetFst(tuple_.fst_id);
01057 fst->InitArcIterator(tuple_.fst_state, &local_data_);
01058
01059 arcs_ = local_data_.arcs;
01060
01061
01062 bool has_final_arc = fst_.GetImpl()->ComputeFinalArc(
01063 tuple_,
01064 &final_arc_,
01065 kArcValueFlags & ~kArcNextStateValue);
01066
01067 final_flags_ = kArcValueFlags & ~kArcNextStateValue;
01068
01069 num_arcs_ = local_data_.narcs;
01070 if (has_final_arc)
01071 ++num_arcs_;
01072
01073
01074 offset_ = num_arcs_ - local_data_.narcs;
01075
01076
01077 data_flags_ = 0;
01078 }
01079 }
01080 }
01081
01082 ~ArcIterator() {
01083 if (cache_data_.ref_count)
01084 --(*cache_data_.ref_count);
01085 if (local_data_.ref_count)
01086 --(*local_data_.ref_count);
01087 }
01088
01089 void ExpandAndCache() const {
01090
01091
01092
01093
01094
01095 fst_.InitArcIterator(state_, &cache_data_);
01096 arcs_ = cache_data_.arcs;
01097 data_flags_ = kArcValueFlags;
01098 offset_ = 0;
01099
01100 }
01101
01102 void Init() {
01103 if (flags_ & kArcNoCache) {
01104
01105 arcs_ = local_data_.arcs;
01106
01107 data_flags_ = kArcWeightValue;
01108 if (!fst_.GetImpl()->EpsilonOnReplace())
01109 data_flags_ |= kArcILabelValue;
01110
01111
01112 offset_ = num_arcs_ - local_data_.narcs;
01113 } else {
01114 ExpandAndCache();
01115 }
01116 }
01117
01118 bool Done() const { return pos_ >= num_arcs_; }
01119
01120 const A& Value() const {
01121
01122 if (!data_flags_) {
01123 CHECK(!(flags_ & kArcNoCache));
01124 ExpandAndCache();
01125 }
01126
01127 if (pos_ - offset_ >= 0) {
01128 const A& arc = arcs_[pos_ - offset_];
01129 if ((data_flags_ & flags_) == (flags_ & kArcValueFlags)) {
01130
01131
01132 return arc;
01133 } else {
01134
01135 fst_.GetImpl()->ComputeArc(tuple_, arc, &arc_, flags_ & kArcValueFlags);
01136 return arc_;
01137 }
01138 } else {
01139 if ((final_flags_ & flags_) != (flags_ & kArcValueFlags)) {
01140
01141
01142
01143 fst_.GetImpl()->ComputeFinalArc(tuple_, &final_arc_,
01144 flags_ & kArcValueFlags);
01145 final_flags_ = flags_ & kArcValueFlags;
01146 }
01147 return final_arc_;
01148 }
01149 }
01150
01151 void Next() { ++pos_; }
01152
01153 size_t Position() const { return pos_; }
01154
01155 void Reset() { pos_ = 0; }
01156
01157 void Seek(size_t pos) { pos_ = pos; }
01158
01159 uint32 Flags() const { return flags_; }
01160
01161 void SetFlags(uint32 f, uint32 mask) {
01162
01163
01164 flags_ &= ~mask;
01165 flags_ |= (f & fst_.GetImpl()->ArcIteratorFlags());
01166
01167
01168
01169 if (!(flags_ & kArcNoCache) && data_flags_ != kArcValueFlags) {
01170 if (!fst_.GetImpl()->HasArcs(state_))
01171 data_flags_ = 0;
01172 }
01173
01174
01175 if ((f & kArcNoCache) && (!data_flags_))
01176 Init();
01177 }
01178
01179 private:
01180 const ReplaceFst<A, T> &fst_;
01181 StateId state_;
01182 mutable typename T::StateTuple tuple_;
01183
01184 ssize_t pos_;
01185 mutable ssize_t offset_;
01186 ssize_t num_arcs_;
01187 uint32 flags_;
01188 mutable Arc arc_;
01189
01190 mutable ArcIteratorData<Arc> cache_data_;
01191 mutable ArcIteratorData<Arc> local_data_;
01192
01193 mutable const A* arcs_;
01194 mutable uint32 data_flags_;
01195 mutable Arc final_arc_;
01196 mutable uint32 final_flags_;
01197
01198 DISALLOW_COPY_AND_ASSIGN(ArcIterator);
01199 };
01200
01201
01202 template <class A, class T> inline
01203 void ReplaceFst<A, T>::InitStateIterator(StateIteratorData<A> *data) const {
01204 data->base = new StateIterator< ReplaceFst<A, T> >(*this);
01205 }
01206
01207 typedef ReplaceFst<StdArc> StdReplaceFst;
01208
01209
01210
01211
01212
01213
01214
01215
01216
01217
01218
01219
01220
01221
01222
01223 template<class Arc>
01224 void Replace(const vector<pair<typename Arc::Label,
01225 const Fst<Arc>* > >& ifst_array,
01226 MutableFst<Arc> *ofst, typename Arc::Label root,
01227 bool epsilon_on_replace) {
01228 ReplaceFstOptions<Arc> opts(root, epsilon_on_replace);
01229 opts.gc_limit = 0;
01230 *ofst = ReplaceFst<Arc>(ifst_array, opts);
01231 }
01232
01233 template<class Arc>
01234 void Replace(const vector<pair<typename Arc::Label,
01235 const Fst<Arc>* > >& ifst_array,
01236 MutableFst<Arc> *ofst, typename Arc::Label root) {
01237 Replace(ifst_array, ofst, root, false);
01238 }
01239
01240 }
01241
01242 #endif /// FST_LIB_REPLACE_H__
01243